package cc.redberry.core.tensorgenerator;

import cc.redberry.core.context.CC;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesSymmetries;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.indices.StructureOfIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.number.Rational;
import cc.redberry.core.solver.frobenius.FrobeniusSolver;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.FastTensors;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorBuilder;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.expand.ExpandTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeUpperLowerIndicesTransformation;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.util.ArithmeticUtils;

/* loaded from: input_file:cc/redberry/core/tensorgenerator/TensorGenerator.class */
/**
 * Generates a tensor with a prescribed set of indices as a sum of products of sample tensors.
 *
 * <p>For the given indices the generator enumerates (via {@link FrobeniusSolver}) every way of
 * picking non-negative multiplicities of the samples such that the numbers of free upper and free
 * lower indices of the picked samples add up to the numbers of upper and lower indices requested.
 * Every such combination becomes one term of the result; if the requested indices carry
 * non-trivial symmetries the sum is additionally symmetrized.
 *
 * <p>Instances are never exposed: use {@link #generate} or {@link #generateStructure}.
 */
public class TensorGenerator {
    /** Sample tensors the result is built from (possibly expanded, see {@link #expandSamples}). */
    private final Tensor[] samples;
    /** Sorted lower indices of {@link #indices}. */
    private final int[] lowerArray;
    /** Sorted upper indices of {@link #indices}. */
    private final int[] upperArray;
    /** Symbols generated by this instance as coefficients (only filled when {@link #withCoefficients}). */
    private final List<SimpleTensor> coefficients = new ArrayList<SimpleTensor>();
    /** Whether a symmetrized term that is a sum gets one common coefficient and is divided by its size. */
    private final boolean symmetricForm;
    /** Indices of the tensor to generate. */
    private final SimpleIndices indices;
    /** The generated tensor; assigned once by {@link #generate()} during construction. */
    private Tensor result;
    /** Whether symbolic coefficients are generated for the terms. */
    private final boolean withCoefficients;

    /**
     * Hash-set key that identifies sample tensors up to an index mapping.
     *
     * <p>Two wrappers are equal when the structures of their free indices are equal and
     * {@link IndexMappings#anyMappingExists} reports a mapping between the wrapped tensors.
     */
    public static class Wrapper {
        private final Tensor tensor;
        private final StructureOfIndices freeIndices;

        private Wrapper(Tensor tensor) {
            this.tensor = tensor;
            // The cast of null selects the createSimple(IndicesSymmetries, Indices) overload.
            this.freeIndices = StructureOfIndices.create(
                    IndicesFactory.createSimple((IndicesSymmetries) null, tensor.getIndices().getFree()));
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Wrapper other = (Wrapper) obj;
            return this.freeIndices.equals(other.freeIndices)
                    && IndexMappings.anyMappingExists(this.tensor, other.tensor);
        }

        /**
         * Delegates to the wrapped tensor's hash code.
         *
         * <p>NOTE(review): consistency with {@link #equals} assumes that tensors related by an
         * index mapping have equal hash codes - confirm against {@code Tensor.hashCode()}.
         */
        @Override
        public int hashCode() {
            return this.tensor.hashCode();
        }
    }

    /**
     * Stores the configuration and immediately runs the generation.
     *
     * @param indices          indices of the tensor to generate
     * @param samples          sample tensors
     * @param symmetricForm    see {@link #symmetricForm}
     * @param withCoefficients see {@link #withCoefficients}
     * @param expandStates     if {@code true} the samples are first passed through
     *                         {@link #expandSamples}
     */
    private TensorGenerator(SimpleIndices indices, Tensor[] samples, boolean symmetricForm,
                            boolean withCoefficients, boolean expandStates) {
        this.samples = expandStates ? expandSamples(samples) : samples;
        this.indices = indices;
        this.symmetricForm = symmetricForm;
        this.lowerArray = indices.getLower().toArray();
        this.upperArray = indices.getUpper().toArray();
        this.withCoefficients = withCoefficients;
        Arrays.sort(this.lowerArray);
        Arrays.sort(this.upperArray);
        generate();
    }

    /**
     * Builds {@link #result}: enumerates all admissible sample combinations, turns each one into
     * a term and sums the terms, symmetrizing the sum when the requested indices have
     * non-trivial symmetries.
     */
    private void generate() {
        final int sampleCount = samples.length;

        // Two linear equations for the solver (one for lower, one for upper indices): the
        // first sampleCount entries are the per-sample free index counts, the last entry is
        // the total that has to be reached.
        final int[] lowerEquation = new int[sampleCount + 1];
        final int[] upperEquation = new int[sampleCount + 1];
        for (int s = 0; s < sampleCount; s++) {
            Indices free = samples[s].getIndices().getFree();
            lowerEquation[s] = free.getLower().size();
            upperEquation[s] = free.getUpper().size();
        }
        lowerEquation[sampleCount] = lowerArray.length;
        upperEquation[sampleCount] = upperArray.length;

        FrobeniusSolver solver = new FrobeniusSolver(new int[][]{lowerEquation, upperEquation});
        SumBuilder sum = new SumBuilder();

        int[] combination;
        while ((combination = solver.take()) != null) {
            List<Tensor> factors = new ArrayList<Tensor>();
            // Positions of the next unused target index; restarted for every combination.
            int upperCursor = 0;
            int lowerCursor = 0;

            for (int s = 0; s < combination.length; s++) {
                Tensor sample = samples[s];
                Indices free = sample.getIndices().getFree();
                Indices lower = free.getLower();
                Indices upper = free.getUpper();
                int upperSize = upper.size();
                int lowerSize = lower.size();

                // combination[s] is the multiplicity of this sample in the current term.
                for (int copy = 0; copy < combination[s]; copy++) {
                    int[] from = new int[upperSize + lowerSize];
                    int[] to = new int[from.length];
                    for (int k = 0; k < upperSize; k++) {
                        from[k] = upper.get(k);
                        to[k] = upperArray[upperCursor++];
                    }
                    for (int k = 0; k < lowerSize; k++) {
                        from[upperSize + k] = lower.get(k);
                        to[upperSize + k] = lowerArray[lowerCursor++];
                    }
                    factors.add(ApplyIndexMapping.applyIndexMapping(
                            sample, new Mapping(from, to), indices.getAllIndices().copy()));
                }
            }

            Tensor[] product = factors.toArray(new Tensor[factors.size()]);
            Tensors.resolveAllDummies(product);
            Tensor term = SymmetrizeUpperLowerIndicesTransformation.symmetrizeUpperLowerIndices(
                    Tensors.multiplyAndRenameConflictingDummies(product));

            boolean termIsSum = term instanceof Sum;
            if (symmetricForm || !termIsSum) {
                // One coefficient for the whole term; a sum is divided by its number of summands.
                Tensor coefficient;
                if (withCoefficients) {
                    coefficient = CC.generateNewSymbol();
                    coefficients.add((SimpleTensor) coefficient);
                } else {
                    coefficient = Complex.ONE;
                }
                Tensor normalization = termIsSum
                        ? new Complex(new Rational(1, term.size()))
                        : Complex.ONE;
                term = Tensors.multiply(coefficient, term, normalization);
            } else if (withCoefficients) {
                term = FastTensors.multiplySumElementsOnFactors((Sum) term);
            }
            sum.put(term);
        }

        Tensor generated = sum.build();
        this.result = indices.getSymmetries().isTrivial() ? generated : symmetrize(generated);
    }

    /**
     * Symmetrizes the tensor according to {@link #indices}, expands it and normalizes the scalar
     * coefficients of the resulting terms: scalars that are related by an index mapping are
     * replaced by one shared coefficient (negated when the mapping carries a sign).
     *
     * @param tensor tensor to symmetrize
     * @return the expanded symmetrized tensor; returned as is when it is not a {@link Sum}
     */
    private Tensor symmetrize(Tensor tensor) {
        Tensor expanded = ExpandTransformation.expand(
                new SymmetrizeTransformation(indices, false).transform(tensor));
        if (!(expanded instanceof Sum)) {
            return expanded;
        }

        // Scalar hash code -> list of {original scalar, coefficient that replaced it}.
        TIntObjectHashMap<List<Tensor[]>> knownByHash = new TIntObjectHashMap<List<Tensor[]>>();
        TensorBuilder rebuilt = expanded.getBuilder();

        for (Iterator<Tensor> it = expanded.iterator(); it.hasNext();) {
            Tensor term = it.next();
            assert term instanceof Product;
            if (!(term instanceof Product)) {
                // NOTE(review): with assertions disabled such a term is silently dropped.
                continue;
            }
            Product product = (Product) term;
            Tensor[] scalars = product.getAllScalarsWithoutFactor();
            if (scalars.length == 0) {
                // NOTE(review): terms without a scalar sub-product are not put into the
                // rebuilt sum - confirm that dropping them is intended.
                continue;
            }
            assert scalars.length == 1;

            Tensor oldCoefficient = scalars[0];
            int hash = oldCoefficient.hashCode();
            List<Tensor[]> known = knownByHash.get(hash);
            if (known == null) {
                known = new ArrayList<Tensor[]>();
                knownByHash.put(hash, known);
            }

            // Scoped per term so a value from a previous term can never leak into this one.
            Tensor newCoefficient = null;
            Mapping match = null;
            for (Tensor[] entry : known) {
                match = IndexMappings.getFirst(entry[0], oldCoefficient);
                if (match != null) {
                    newCoefficient = match.getSign() ? Tensors.negate(entry[1]) : entry[1];
                    // Stop at the first hit: continuing would let a later, non-matching
                    // entry reset the match and register this scalar a second time.
                    break;
                }
            }

            if (match == null) {
                if (withCoefficients && !(oldCoefficient instanceof SimpleTensor)) {
                    // A composite scalar is replaced by a fresh symbol; the symbols it was
                    // built from are no longer coefficients of the result.
                    newCoefficient = CC.generateNewSymbol();
                    coefficients.add((SimpleTensor) newCoefficient);
                    coefficients.removeAll(TensorUtils.getAllSymbols(oldCoefficient));
                } else {
                    // Nothing to substitute: keep the term's own scalar.
                    newCoefficient = oldCoefficient;
                }
                known.add(new Tensor[]{oldCoefficient, newCoefficient});
            }

            rebuilt.put(Tensors.multiply(
                    product.getFactor(), newCoefficient, product.getDataSubProduct()));
        }
        return rebuilt.build();
    }

    /** Returns the tensor produced during construction. */
    private Tensor result() {
        return this.result;
    }

    /**
     * Generates the tensor with the given indices from the given samples.
     *
     * @param simpleIndices indices of the tensor to generate; their symmetries, when
     *                      non-trivial, are applied to the result
     * @param tensorArr     sample tensors
     * @param z             symmetric form: when {@code true}, a term that becomes a sum after
     *                      upper/lower symmetrization gets a single coefficient and is divided
     *                      by its number of summands
     * @param z2            with coefficients: when {@code true}, new symbols are generated as
     *                      coefficients of the terms
     * @param z3            when {@code true}, the samples are first replaced by the result of
     *                      {@code TensorGeneratorUtils.allStatesCombinations} applied to each
     *                      distinct sample
     * @return the generated tensor
     */
    public static Tensor generate(SimpleIndices simpleIndices, Tensor[] tensorArr, boolean z, boolean z2, boolean z3) {
        return new TensorGenerator(simpleIndices, tensorArr, z, z2, z3).result();
    }

    /**
     * Same as {@link #generate}, but returns the generated tensor together with all symbols
     * found in it (as reported by {@code TensorUtils.getAllSymbols}).
     *
     * @param simpleIndices indices of the tensor to generate
     * @param tensorArr     sample tensors
     * @param z             symmetric form flag, see {@link #generate}
     * @param z2            with coefficients flag, see {@link #generate}
     * @param z3            expand samples flag, see {@link #generate}
     * @return the generated tensor and the symbols it contains
     */
    public static GeneratedTensor generateStructure(SimpleIndices simpleIndices, Tensor[] tensorArr, boolean z, boolean z2, boolean z3) {
        TensorGenerator generator = new TensorGenerator(simpleIndices, tensorArr, z, z2, z3);
        Tensor generated = generator.result();
        SimpleTensor[] symbols = (SimpleTensor[]) TensorUtils.getAllSymbols(generated).toArray(new SimpleTensor[0]);
        return new GeneratedTensor(symbols, generated);
    }

    /**
     * Removes samples that are equal up to an index mapping (see {@link Wrapper}) and replaces
     * every remaining sample by the tensors returned from
     * {@code TensorGeneratorUtils.allStatesCombinations}.
     *
     * @param tensorArr original samples
     * @return the expanded samples, in the iteration order of a {@link HashSet}
     */
    private static Tensor[] expandSamples(Tensor[] tensorArr) {
        HashSet<Wrapper> distinct = new HashSet<Wrapper>();
        for (Tensor sample : tensorArr) {
            distinct.add(new Wrapper(sample));
        }
        ArrayList<Tensor> expanded = new ArrayList<Tensor>();
        for (Wrapper wrapper : distinct) {
            Tensor[] states = TensorGeneratorUtils.allStatesCombinations(wrapper.tensor);
            // Reserve room for what is about to be appended, on top of what is already there.
            expanded.ensureCapacity(expanded.size() + states.length);
            expanded.addAll(Arrays.asList(states));
        }
        return expanded.toArray(new Tensor[expanded.size()]);
    }
}
