package cc.redberry.core.transformations.collect;

import cc.redberry.core.context.defaults.LatinLowerCaseConverter;
import cc.redberry.core.groups.permutations.Permutations;
import cc.redberry.core.indexgenerator.IndexGeneratorImpl;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesBuilder;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesSymmetries;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.Expression;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.expand.ExpandPort;
import cc.redberry.core.transformations.options.Creator;
import cc.redberry.core.transformations.options.Option;
import cc.redberry.core.transformations.options.Options;
import cc.redberry.core.transformations.powerexpand.PowerUnfoldTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.OutputPort;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;

/* loaded from: input_file:cc/redberry/core/transformations/collect/CollectTransformation.class */
public class CollectTransformation implements Transformation {
    // Values of SimpleTensor.getName() for every pattern tensor; filled in the main constructor
    // and queried by match().
    private final TIntHashSet patternsNames;
    // PowerUnfoldTransformation built from the pattern tensors; applied in split() before a
    // product (or power of a product) is scanned for matching factors.
    private final Transformation powerExpand;
    // Applied sequentially (Transformation.Util.applySequentially) to the unmatched remainder
    // and to every collected coefficient sum.
    private final Transformation[] transformations;
    // Forwarded as the second argument of ExpandPort.createPort in transform1().
    private final boolean expandSymbolic;
    // Decompiled form of the assertion flag: the static initializer sets it to
    // !desiredAssertionStatus(), i.e. it is true when assertions are disabled for this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:cc/redberry/core/transformations/collect/CollectTransformation$CollectOptions.class */
    /**
     * Option holder consumed by the {@code @Creator} constructor of the enclosing class.
     * The fields are populated externally through the {@code @Option} annotations.
     */
    public static final class CollectOptions {

        // Transformation applied to collected coefficients; defaults to the identity.
        // NOTE(review): "index = LatinLowerCaseConverter.TYPE" looks like a decompiler
        // substituting a named constant for a numeric literal (the sibling option uses index = 1,
        // so a literal 0 is presumably intended) - confirm against the original source.
        @Option(name = "Simplifications", index = LatinLowerCaseConverter.TYPE)
        public Transformation simplifications = Transformation.IDENTITY;

        // Forwarded to the enclosing transformation's expandSymbolic flag; defaults to true.
        @Option(name = "ExpandSymbolic", index = 1)
        public boolean expandSymbolic = true;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/redberry/core/transformations/collect/CollectTransformation$DirectIndexMapping.class */
    /**
     * Base class for index mappings described by two parallel arrays: {@code from[k]} is
     * mapped to {@code to[k]}.
     *
     * <p>The constructor keeps the arrays it is given (no copy) and passes them to
     * {@code ArraysUtils.quickSort(from, to)}; subclasses in this file then run a binary search
     * over {@code from}, so the call presumably sorts {@code from} and reorders {@code to}
     * coherently (verify against ArraysUtils).
     */
    public static abstract class DirectIndexMapping implements IndexMapping {
        final int[] from;
        final int[] to;

        private DirectIndexMapping(int[] iArr, int[] iArr2) {
            this.from = iArr;
            this.to = iArr2;
            // Sorts in place: the caller's arrays are the ones stored above.
            ArraysUtils.quickSort(this.from, this.to);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/redberry/core/transformations/collect/CollectTransformation$Split.class */
    /**
     * One term of the expanded input, separated into the factors that match the collect
     * patterns and the list of coefficient summands accumulated for those factors.
     *
     * <p>{@link #hashCode} is the {@code Arrays.hashCode} of the sorted factors and is used by the
     * enclosing class as the bucket key; {@link #forbidden} holds the index names of all factors.
     */
    public static final class Split {
        final Tensor[] factors;
        final ArrayList<Tensor> summands;
        final int hashCode;
        final int[] forbidden;

        /**
         * @param tensorArr matched factors; sorted in place and stored without copying
         * @param tensor    the first (initial) coefficient summand
         */
        private Split(Tensor[] tensorArr, Tensor tensor) {
            // Sort first so that both the stored array and the hash see the sorted order.
            Arrays.sort(tensorArr);
            this.factors = tensorArr;
            this.hashCode = Arrays.hashCode(tensorArr);
            this.forbidden = IndicesUtils.getIndicesNames(new IndicesBuilder().append(tensorArr).getIndices());
            this.summands = new ArrayList<>();
            this.summands.add(tensor);
        }

        @Override
        public int hashCode() {
            return this.hashCode;
        }

        /**
         * Builds {@code factors * T(sum of summands)}, where {@code T} is the sequential
         * application of the given transformations to the summed coefficients.
         */
        Tensor toTensor(Transformation[] transformationArr) {
            Tensor[] collected = this.summands.toArray(new Tensor[this.summands.size()]);
            Tensor coefficient = Transformation.Util.applySequentially(Tensors.sum(collected), transformationArr);
            // Copy of the factors with one extra trailing slot for the coefficient.
            int factorCount = this.factors.length;
            Tensor[] multipliers = Arrays.copyOf(this.factors, factorCount + 1, Tensor[].class);
            multipliers[factorCount] = coefficient;
            return Tensors.multiply(multipliers);
        }

        @Override
        public String toString() {
            Tensor product = Tensors.multiply(this.factors);
            Tensor sum = Tensors.sum(this.summands.toArray(new Tensor[this.summands.size()]));
            return product + " : " + sum;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:cc/redberry/core/transformations/collect/CollectTransformation$StateSensitiveMapping.class */
    /**
     * Mapping that looks the given index up in {@code from} exactly as passed and returns the
     * entry of {@code to} at the same position; indices that are not found are returned as is.
     */
    public static final class StateSensitiveMapping extends DirectIndexMapping {
        private StateSensitiveMapping(int[] iArr, int[] iArr2) {
            super(iArr, iArr2);
        }

        @Override // cc.redberry.core.indexmapping.IndexMapping
        public int map(int i) {
            int position = Arrays.binarySearch(this.from, i);
            if (position < 0) {
                // Not a key of this mapping: identity.
                return i;
            }
            return this.to[position];
        }
    }

    public CollectTransformation(SimpleTensor[] simpleTensorArr, Transformation[] transformationArr, boolean z) {
        this.patternsNames = new TIntHashSet();
        this.powerExpand = new PowerUnfoldTransformation(simpleTensorArr);
        for (SimpleTensor simpleTensor : simpleTensorArr) {
            this.patternsNames.add(simpleTensor.getName());
        }
        this.transformations = transformationArr;
        this.expandSymbolic = z;
    }

    /**
     * Same as the three-argument constructor with the expand-symbolic flag set to {@code true}.
     *
     * @param simpleTensorArr   pattern tensors to collect by
     * @param transformationArr transformations applied to collected coefficients
     */
    public CollectTransformation(SimpleTensor[] simpleTensorArr, Transformation[] transformationArr) {
        this(simpleTensorArr, transformationArr, true);
    }

    /**
     * Collects by the given patterns with no extra transformations (empty array) and the
     * expand-symbolic flag set to {@code true}.
     *
     * @param simpleTensorArr pattern tensors to collect by
     */
    public CollectTransformation(SimpleTensor... simpleTensorArr) {
        this(simpleTensorArr, new Transformation[0]);
    }

    /**
     * Options-based constructor: uses {@code collectOptions.simplifications} as the single
     * coefficient transformation and {@code collectOptions.expandSymbolic} as the flag.
     *
     * @param simpleTensorArr pattern tensors to collect by
     * @param collectOptions  option holder; dereferenced directly, so it must not be null
     */
    @Creator(vararg = true, hasArgs = true)
    public CollectTransformation(SimpleTensor[] simpleTensorArr, @Options CollectOptions collectOptions) {
        this(simpleTensorArr, new Transformation[]{collectOptions.simplifications}, collectOptions.expandSymbolic);
    }

    /**
     * Collects by the given patterns with no extra transformations (empty array) and an
     * explicit expand-symbolic flag.
     *
     * @param simpleTensorArr pattern tensors to collect by
     * @param z               value of the expand-symbolic flag
     */
    public CollectTransformation(SimpleTensor[] simpleTensorArr, boolean z) {
        this(simpleTensorArr, new Transformation[0], z);
    }

    /**
     * Applies the collect transformation. An {@link Expression} is processed child by child
     * (this transformation is applied to each child); any other tensor is collected directly.
     *
     * @param tensor tensor to transform
     * @return the transformed tensor
     */
    @Override // cc.redberry.core.transformations.Transformation
    public Tensor transform(Tensor tensor) {
        if (tensor instanceof Expression) {
            return Transformation.Util.applyToEachChild(tensor, this);
        }
        return transform1(tensor);
    }

    /**
     * Collects the terms produced by the expand port of {@code tensor}.
     *
     * <p>Each term is split into pattern factors and a coefficient. Terms without pattern
     * factors go to a separate sum. The remaining splits are bucketed by their hash; inside a
     * bucket a new split is merged into the first existing split whose factors match (its
     * coefficient is re-indexed and appended to that split's summands), otherwise it is
     * appended to the bucket as a new entry.
     *
     * @param tensor tensor to collect
     * @return sum of the transformed unmatched part and of every collected split
     */
    private Tensor transform1(Tensor tensor) {
        SumBuilder unmatched = new SumBuilder();
        TIntObjectHashMap<ArrayList<Split>> buckets = new TIntObjectHashMap<>();
        OutputPort<Tensor> port = ExpandPort.createPort(tensor, this.expandSymbolic);
        for (Tensor term = port.take(); term != null; term = port.take()) {
            Split candidate = split(term);
            if (candidate.factors.length == 0) {
                // No pattern factors in this term.
                unmatched.put(term);
                continue;
            }
            ArrayList<Split> bucket = buckets.get(candidate.hashCode);
            if (bucket == null) {
                bucket = new ArrayList<>();
                bucket.add(candidate);
                buckets.put(candidate.hashCode, bucket);
                continue;
            }
            boolean merged = false;
            for (Split base : bucket) {
                int[] permutation = matchFactors(base.factors, candidate.factors);
                if (permutation == null) {
                    continue;
                }
                Tensor coefficient = candidate.summands.get(0);
                Tensor[] reordered = (Tensor[]) Permutations.permute(candidate.factors, permutation);
                // Rename the coefficient's indices so that it is expressed relative to the
                // factors of the split it is merged into.
                base.summands.add(ApplyIndexMapping.applyIndexMappingAutomatically(coefficient, IndexMappings.createBijectiveProductPort(reordered, base.factors).take(), base.forbidden));
                merged = true;
                break;
            }
            if (!merged) {
                bucket.add(candidate);
            }
        }
        Tensor remainder = Transformation.Util.applySequentially(unmatched.build(), this.transformations);
        SumBuilder result = new SumBuilder();
        result.put(remainder);
        for (ArrayList<Split> bucket : buckets.valueCollection()) {
            for (Split collected : bucket) {
                result.put(collected.toTensor(this.transformations));
            }
        }
        return result.build();
    }

    /**
     * Tells whether {@code tensor} is one of the collect patterns: either a
     * {@link SimpleTensor} whose hash code is in {@code patternsNames}, or a positive integer
     * power whose base (child 0) has such a hash code.
     *
     * <p>NOTE(review): the set is filled with {@code getName()} values but queried with
     * {@code hashCode()}; this relies on the two coinciding for the pattern tensors - confirm.
     */
    private boolean match(Tensor tensor) {
        Tensor base;
        if (tensor instanceof SimpleTensor) {
            base = tensor;
        } else if (TensorUtils.isPositiveIntegerPower(tensor)) {
            base = tensor.get(0);
        } else {
            return false;
        }
        return this.patternsNames.contains(base.hashCode());
    }

    /**
     * Splits a single term into its pattern factors and the remaining coefficient.
     *
     * <p>If the term contains no pattern factor, the returned split has an empty factor array
     * and the term itself as its only summand. Otherwise every index of a matched factor that
     * is either a free index of the term, or an upper index contracted among the matched
     * factors themselves, is replaced by a freshly generated index, and a Kronecker delta
     * linking old and new index is multiplied into the coefficient (metrics are eliminated
     * from that product afterwards).
     *
     * @param tensor one term taken from the expand port
     * @return the split of the term
     */
    private Split split(Tensor tensor) {
        Tensor[] matched;
        Tensor remainder;
        boolean simpleLike = (tensor instanceof SimpleTensor) || TensorUtils.isPositiveIntegerPowerOfSimpleTensor(tensor);
        if (simpleLike) {
            if (!match(tensor)) {
                return new Split(new Tensor[0], tensor);
            }
            matched = new Tensor[]{tensor};
            remainder = Complex.ONE;
        } else if ((tensor instanceof Product) || TensorUtils.isPositiveIntegerPowerOfProduct(tensor)) {
            tensor = this.powerExpand.transform(tensor);
            // Early check: does any factor match at all?
            Tensor scanned = tensor instanceof Product ? tensor : tensor.get(0);
            boolean anyMatch = false;
            for (Iterator<Tensor> probe = scanned.iterator(); !anyMatch && probe.hasNext(); ) {
                anyMatch = match(probe.next());
            }
            if (!anyMatch) {
                return new Split(new Tensor[0], tensor);
            }
            if (!$assertionsDisabled && !(tensor instanceof Product)) {
                throw new AssertionError();
            }
            ArrayList<Tensor> picked = new ArrayList<>();
            remainder = tensor;
            for (Iterator<Tensor> factors = tensor.iterator(); factors.hasNext(); ) {
                Tensor factor = factors.next();
                if (!match(factor)) {
                    continue;
                }
                picked.add(factor);
                if (!$assertionsDisabled && remainder == Complex.ONE) {
                    throw new AssertionError();
                }
                // Strip the matched factor from what is left of the product.
                if (remainder instanceof Product) {
                    remainder = ((Product) remainder).remove(factor);
                } else {
                    remainder = Complex.ONE;
                }
            }
            matched = picked.toArray(new Tensor[picked.size()]);
        } else {
            return new Split(new Tensor[0], tensor);
        }

        TIntHashSet freeNames = new TIntHashSet(IndicesUtils.getIndicesNames(tensor.getIndices().getFree()));
        Indices factorIndices = new IndicesBuilder().append(matched).getIndices();
        TIntHashSet contractedNames = new TIntHashSet(IndicesUtils.getIntersections(factorIndices.getUpper().toArray(), factorIndices.getLower().toArray()));
        IntArrayList oldIndices = new IntArrayList();
        IntArrayList newIndices = new IntArrayList();
        ArrayList<Tensor> coefficientParts = new ArrayList<>();
        IndexGeneratorImpl generator = new IndexGeneratorImpl(TensorUtils.getAllIndicesNamesT(tensor).toArray());
        for (int k = 0; k < matched.length; k++) {
            oldIndices.clear();
            newIndices.clear();
            SimpleIndices indices = IndicesFactory.createSimple((IndicesSymmetries) null, matched[k].getIndices());
            for (int position = indices.size() - 1; position >= 0; position--) {
                int index = indices.get(position);
                boolean relabel = freeNames.contains(IndicesUtils.getNameWithType(index))
                        || (IndicesUtils.getState(index) && contractedNames.contains(IndicesUtils.getNameWithType(index)));
                if (!relabel) {
                    continue;
                }
                int fresh = IndicesUtils.setRawState(IndicesUtils.getRawStateInt(index), generator.generate(IndicesUtils.getType(index)));
                oldIndices.add(index);
                newIndices.add(fresh);
                coefficientParts.add(Tensors.createKronecker(index, IndicesUtils.inverseIndexState(fresh)));
            }
            matched[k] = applyDirectMapping(matched[k], new StateSensitiveMapping(oldIndices.toArray(), newIndices.toArray()));
        }
        coefficientParts.add(remainder);
        Tensor coefficient = EliminateMetricsTransformation.eliminate(Tensors.multiply(coefficientParts.toArray(new Tensor[coefficientParts.size()])));
        return new Split(matched, coefficient);
    }

    /**
     * Tries to pair every element of {@code tensorArr} with a distinct element of
     * {@code tensorArr2} such that each pair satisfies {@code matchSimpleTensors}.
     *
     * <p>The arrays are walked in runs delimited by a change of hash code. A run of length one
     * must match position-for-position; inside a longer run each element of {@code tensorArr}
     * greedily takes the first still-unassigned matching element of {@code tensorArr2} from the
     * same run.
     *
     * <p>Fixes relative to the previous version: for runs of two or more factors the method
     * returned {@code null} unconditionally after the first element, did not stop scanning
     * once an element was assigned, and then overwrote the last slot of the run with the
     * single-element rule. Such runs are now matched properly.
     *
     * @param tensorArr  factors of the base split
     * @param tensorArr2 factors of the split being merged
     * @return {@code Permutations.inverse(p)} where {@code p[j]} is the position in
     *         {@code tensorArr} assigned to {@code tensorArr2[j]}, or {@code null} if the
     *         lengths differ or some element cannot be paired
     */
    static int[] matchFactors(Tensor[] tensorArr, Tensor[] tensorArr2) {
        if (tensorArr.length != tensorArr2.length) {
            return null;
        }
        int length = tensorArr.length;
        int[] permutation = new int[length];
        Arrays.fill(permutation, -1);
        int begin = 0;
        for (int end = 1; end <= length; end++) {
            // NOTE(review): the run boundary compares against tensorArr2[end - 1]; this is
            // presumably equivalent to comparing within tensorArr when both arrays are sorted
            // by hash (as Split does) - kept as is.
            if (end != length && tensorArr[end].hashCode() == tensorArr2[end - 1].hashCode()) {
                continue;
            }
            if (end - 1 != begin) {
                // Run [begin, end) of several equal-hash factors: greedy assignment.
                nextFactor:
                for (int n = begin; n < end; n++) {
                    for (int j = begin; j < end; j++) {
                        if (permutation[j] == -1 && matchSimpleTensors(tensorArr[n], tensorArr2[j])) {
                            permutation[j] = n;
                            continue nextFactor;
                        }
                    }
                    // No free counterpart for tensorArr[n] in this run.
                    return null;
                }
            } else {
                // Run of a single factor: positions must match directly.
                if (!matchSimpleTensors(tensorArr[end - 1], tensorArr2[end - 1])) {
                    return null;
                }
                permutation[end - 1] = end - 1;
            }
            begin = end;
        }
        return Permutations.inverse(permutation);
    }

    /**
     * Checks whether two factors can be identified with each other.
     *
     * <p>They must have the same runtime class and the same hash code. Positive integer
     * powers of simple tensors additionally need equal exponents (child 1) and matching bases
     * (child 0). Tensor fields additionally need a positive index mapping to exist between
     * each pair of corresponding children.
     *
     * @return {@code true} if the two tensors match under the rules above
     */
    private static boolean matchSimpleTensors(Tensor tensor, Tensor tensor2) {
        if (tensor.getClass() != tensor2.getClass()) {
            return false;
        }
        if (tensor.hashCode() != tensor2.hashCode()) {
            return false;
        }
        if (TensorUtils.isPositiveIntegerPowerOfSimpleTensor(tensor)) {
            if (!TensorUtils.isPositiveIntegerPowerOfSimpleTensor(tensor2)) {
                return false;
            }
            if (!tensor.get(1).equals(tensor2.get(1))) {
                return false;
            }
            return matchSimpleTensors(tensor.get(0), tensor2.get(0));
        }
        if (tensor instanceof TensorField) {
            // Compare children from last to first.
            int child = tensor.size();
            while (--child >= 0) {
                if (!IndexMappings.positiveMappingExists(tensor.get(child), tensor2.get(child))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Applies {@code directIndexMapping} to the indices of a simple tensor or tensor field and
     * rebuilds the tensor with the mapped indices (field arguments and argument indices are
     * carried over unchanged). Any other tensor is returned as is; with assertions enabled it
     * must then have no indices.
     *
     * @param tensor             tensor whose indices are to be renamed
     * @param directIndexMapping mapping to apply
     * @return the tensor with renamed indices, or {@code tensor} itself if it is not a
     *         {@link SimpleTensor}
     */
    private static Tensor applyDirectMapping(Tensor tensor, DirectIndexMapping directIndexMapping) {
        if (!(tensor instanceof SimpleTensor)) {
            if (!$assertionsDisabled && tensor.getIndices().size() != 0) {
                throw new AssertionError();
            }
            return tensor;
        }
        SimpleTensor simple = (SimpleTensor) tensor;
        SimpleIndices remapped = simple.getIndices().applyIndexMapping((IndexMapping) directIndexMapping);
        if (tensor instanceof TensorField) {
            TensorField field = (TensorField) simple;
            return Tensors.field(simple.getName(), remapped, field.getArgIndices(), field.getArguments());
        }
        return Tensors.simpleTensor(simple.getName(), remapped);
    }

    // Initializes the decompiled assertion flag: true when assertions are NOT enabled for
    // this class, so that "if (!$assertionsDisabled && ...) throw new AssertionError()" guards
    // only fire when assertions are enabled.
    static {
        $assertionsDisabled = !CollectTransformation.class.desiredAssertionStatus();
    }
}
