package cc.redberry.core.solver;

import cc.redberry.core.context.CC;
import cc.redberry.core.indices.IndexType;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesSymmetries;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.Expression;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Split;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensorgenerator.GeneratedTensor;
import cc.redberry.core.tensorgenerator.TensorGenerator;
import cc.redberry.core.tensorgenerator.TensorGeneratorUtils;
import cc.redberry.core.transformations.CollectNonScalarsTransformation;
import cc.redberry.core.transformations.EliminateMetricsTransformation;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.transformations.TransformationCollection;
import cc.redberry.core.transformations.expand.ExpandTransformation;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:cc/redberry/core/solver/ReduceEngine.class */
/**
 * Reduces a system of tensor equations in unknown tensors to a system of scalar equations in
 * generated coefficients.
 *
 * <p>Every scalar unknown is replaced by a freshly generated symbol. Every indexed unknown is
 * replaced by a structure produced by {@link TensorGenerator#generateStructure} from "sample"
 * tensors found in the equations, and the coefficients of that structure become the new unknowns.
 */
public final class ReduceEngine {
    /** Maximum number of substitute-and-simplify passes performed on a single equation. */
    private static final int ITERATION_LIMIT = 10000;

    private ReduceEngine() {
    }

    /**
     * Equivalent to
     * {@link #reduceToSymbolicSystem(Expression[], SimpleTensor[], Transformation[], boolean[])}
     * with every per-unknown flag set to {@code false}.
     *
     * @param equations       equations {@code lhs = rhs} to reduce
     * @param unknowns        unknown tensors
     * @param transformations additional simplifications applied after each substitution pass
     * @return the reduced system, or {@code null} (see the four-argument overload)
     */
    public static ReducedSystem reduceToSymbolicSystem(Expression[] equations, SimpleTensor[] unknowns,
                                                       Transformation[] transformations) {
        return reduceToSymbolicSystem(equations, unknowns, transformations, new boolean[unknowns.length]);
    }

    /**
     * Reduces {@code equations} to scalar equations in generated coefficients.
     *
     * <p>Each equation is rewritten as {@code lhs - rhs}, expanded and metric-eliminated. The
     * substitutions {@code unknown = general form} are then applied repeatedly and simplified until
     * no unknown remains. Finally, the result is split into equations of the form {@code ... = 0}.
     *
     * @param equations       equations {@code lhs = rhs} to reduce
     * @param unknowns        unknown tensors
     * @param transformations additional simplifications applied after each substitution pass;
     *                        metric elimination is always applied before them
     * @param symmetricForms  per-unknown flag forwarded as the third argument of
     *                        {@link TensorGenerator#generateStructure} (presumably requesting a
     *                        symmetric form; confirm). It is not read for unknowns without indices.
     * @return the reduced system, or {@code null} when the equations contain no indexed sample
     *         tensor while at least one unknown carries indices
     * @throws RuntimeException if an equation still contains unknowns after
     *                          {@value #ITERATION_LIMIT} passes
     */
    public static ReducedSystem reduceToSymbolicSystem(Expression[] equations, SimpleTensor[] unknowns,
                                                       Transformation[] transformations, boolean[] symmetricForms) {
        Tensor[] zeroForms = toZeroForms(equations);

        TIntHashSet unknownNames = new TIntHashSet(unknowns.length);
        for (SimpleTensor unknown : unknowns) {
            unknownNames.add(unknown.getName());
        }

        Tensor[] samples = getSamples(zeroForms, unknownNames);
        if (samples.length == 0) {
            // Without samples no structure can be generated for an indexed unknown.
            for (SimpleTensor unknown : unknowns) {
                if (unknown.getIndices().size() != 0) {
                    return null;
                }
            }
        }

        // Order matters: coefficients are generated and collected in unknown order.
        Expression[] substitutions = new Expression[unknowns.length];
        List<SimpleTensor> coefficients = new ArrayList<SimpleTensor>();
        for (int i = 0; i < unknowns.length; ++i) {
            SimpleTensor unknown = unknowns[i];
            if (unknown.getIndices().size() == 0) {
                SimpleTensor symbol = CC.generateNewSymbol();
                coefficients.add(symbol);
                substitutions[i] = Tensors.expression(unknown, symbol);
            } else {
                GeneratedTensor generated = TensorGenerator.generateStructure(
                        unknown.getIndices(), samples, symmetricForms[i], true, true);
                coefficients.addAll(Arrays.asList(generated.coefficients));
                substitutions[i] = Tensors.expression(unknown, generated.generatedTensor);
            }
        }

        List<Transformation> simplifications = new ArrayList<Transformation>(transformations.length + 1);
        simplifications.add(EliminateMetricsTransformation.ELIMINATE_METRICS);
        simplifications.addAll(Arrays.asList(transformations));
        TransformationCollection simplification = new TransformationCollection(simplifications);

        List<Expression> scalarEquations = new ArrayList<Expression>();
        for (Tensor zeroForm : zeroForms) {
            addScalarEquations(reduce(zeroForm, substitutions, simplification, unknownNames), scalarEquations);
        }

        return new ReducedSystem(
                scalarEquations.toArray(new Expression[scalarEquations.size()]),
                coefficients.toArray(new SimpleTensor[coefficients.size()]),
                substitutions);
    }

    /**
     * Rewrites each equation {@code lhs = rhs} as the expanded, metric-eliminated tensor
     * {@code lhs - rhs}.
     */
    private static Tensor[] toZeroForms(Expression[] equations) {
        Tensor[] zeroForms = new Tensor[equations.length];
        for (int i = 0; i < equations.length; ++i) {
            Tensor difference = Tensors.subtract(equations[i].get(0), equations[i].get(1));
            difference = ExpandTransformation.expand(difference, EliminateMetricsTransformation.ELIMINATE_METRICS);
            zeroForms[i] = EliminateMetricsTransformation.eliminate(difference);
        }
        return zeroForms;
    }

    /**
     * Applies all substitutions to {@code equation} and simplifies the result. This repeats for at
     * most {@value #ITERATION_LIMIT} passes, until no tensor named in {@code unknownNames} remains.
     *
     * @throws RuntimeException if unknowns remain after the last allowed pass
     */
    private static Tensor reduce(Tensor equation, Expression[] substitutions,
                                 TransformationCollection simplification, TIntHashSet unknownNames) {
        Tensor current = equation;
        for (int pass = 0; pass < ITERATION_LIMIT; ++pass) {
            for (Expression substitution : substitutions) {
                current = substitution.transform(current);
            }
            current = CollectNonScalarsTransformation.collectNonScalars(
                    simplification.transform(ExpandTransformation.expand(current, simplification)));
            if (!TensorUtils.containsSimpleTensors(current, unknownNames)) {
                return current;
            }
        }
        throw new RuntimeException("Maximum number of iterations exceeded: the system cannot be reduced after 10 000 iterations.");
    }

    /**
     * Appends to {@code out} the scalar equations implied by {@code reduced = 0}.
     *
     * <p>A scalar {@code reduced} yields the single equation {@code reduced = 0}. Otherwise, the scalar
     * part (per {@link Split#splitScalars}) of each term is equated to zero. For a non-sum, that is the
     * single term itself.
     */
    private static void addScalarEquations(Tensor reduced, List<Expression> out) {
        if (reduced.getIndices().size() == 0) {
            out.add(Tensors.expression(reduced, Complex.ZERO));
        } else if (reduced instanceof Sum) {
            for (Tensor term : reduced) {
                out.add(Tensors.expression(Split.splitScalars(term).summand, Complex.ZERO));
            }
        } else {
            out.add(Tensors.expression(Split.splitScalars(reduced).summand, Complex.ZERO));
        }
    }

    /**
     * Collects the sample tensors used to generate structures for indexed unknowns.
     *
     * <p>The samples are:
     * <ul>
     * <li>all index-state combinations of every indexed, non-unknown, non-Kronecker/metric simple
     *     tensor in {@code equations};</li>
     * <li>one Kronecker delta for every index type encountered;</li>
     * <li>lower and upper metrics for every such type that has a metric.</li>
     * </ul>
     */
    private static Tensor[] getSamples(Tensor[] equations, TIntHashSet unknownNames) {
        Collection<SimpleTensor> content = TensorUtils.getAllDiffSimpleTensors(equations);
        List<Tensor> samples = new ArrayList<Tensor>(content.size() + 1);
        Set<IndexType> types = new HashSet<IndexType>();
        for (SimpleTensor tensor : content) {
            SimpleIndices indices = tensor.getIndices();
            if (unknownNames.contains(tensor.getName()) || indices.size() == 0) {
                continue;
            }
            if (Tensors.isKroneckerOrMetric(tensor)) {
                types.add(IndicesUtils.getTypeEnum(indices.get(0)));
                continue;
            }
            for (int i = 0; i < indices.size(); ++i) {
                types.add(IndicesUtils.getTypeEnum(indices.get(i)));
            }
            // NOTE(review): the sample keeps the tensor's own index names. A canonical renaming
            // (e.g. name = position via IndicesUtils.createIndex) may have been intended; confirm
            // against what TensorGenerator expects before changing it.
            samples.addAll(Arrays.asList(TensorGeneratorUtils.allStatesCombinations(
                    Tensors.setIndices(tensor, IndicesFactory.createSimple((IndicesSymmetries) null, indices)))));
        }
        for (IndexType indexType : types) {
            byte type = indexType.getType();
            // Integer.MIN_VALUE sets the high (state) bit. The delta is built with one lower and one
            // upper index, so that bit appears to mark an upper index.
            samples.add(Tensors.createKronecker(IndicesUtils.setType(type, 0),
                    Integer.MIN_VALUE | IndicesUtils.setType(type, 1)));
            if (CC.isMetric(type)) {
                samples.add(Tensors.createMetric(IndicesUtils.setType(type, 0), IndicesUtils.setType(type, 1)));
                samples.add(Tensors.createMetric(Integer.MIN_VALUE | IndicesUtils.setType(type, 0),
                        Integer.MIN_VALUE | IndicesUtils.setType(type, 1)));
            }
        }
        return samples.toArray(new Tensor[samples.size()]);
    }
}
