package cc.redberry.core.transformations;

import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesSymmetries;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.ProductBuilder;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.transformations.substitutions.SubstitutionTransformation;
import cc.redberry.core.transformations.symmetrization.SymmetrizeSimpleTensorTransformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Iterator;

/**
 * Transformation that differentiates a tensor expression with respect to one or more simple
 * tensors (the differentiation variables).
 *
 * <p>Index-free variables are handled by {@link SymbolicDifferentiationRule}, which yields
 * {@code 1} for a simple tensor with the variable's name. Variables carrying indices are handled by
 * {@link SymmetricDifferentiationRule}, whose derivative is built by symmetrizing a placeholder
 * tensor with the variable's inner index symmetries and substituting a product of metrics /
 * Kronecker deltas for it. The optional {@code expandAndContract} transformations are applied to
 * intermediate derivatives throughout the recursion.
 */
public final class DifferentiateTransformation implements Transformation {
    /** Shared empty transformation list; zero-length arrays cannot be mutated. */
    private static final Transformation[] NO_TRANSFORMATIONS = new Transformation[0];

    /** Variables to differentiate by, in order (private copy of the constructor argument). */
    private final SimpleTensor[] vars;
    /** Transformations applied to intermediate results (private copy of the constructor argument). */
    private final Transformation[] expandAndContract;

    /** Strategy describing how a single differentiation variable acts on a simple tensor. */
    public abstract static class SimpleTensorDifferentiationRule {
        /** The variable with respect to which differentiation is performed. */
        protected final SimpleTensor var;

        protected SimpleTensorDifferentiationRule(SimpleTensor variable) {
            this.var = variable;
        }

        /**
         * Differentiates a simple tensor with respect to {@link #var}.
         *
         * @return {@code 0} when the names differ, otherwise the rule-specific derivative
         */
        Tensor differentiateSimpleTensor(SimpleTensor simpleTensor) {
            return simpleTensor.getName() != this.var.getName()
                    ? Complex.ZERO
                    : differentiateSimpleTensorWithoutCheck(simpleTensor);
        }

        /** Returns a rule equivalent to this one, adapted to be used inside {@code tensor}. */
        abstract SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor);

        /** Derivative of {@code simpleTensor}, assuming its name already matches {@link #var}. */
        abstract Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor);

        /** Index names that a differentiated expression must not reuse as dummies. */
        abstract int[] getForbidden();
    }

    /** Rule for index-free variables: the derivative of a matching simple tensor is {@code 1}. */
    public static final class SymbolicDifferentiationRule extends SimpleTensorDifferentiationRule {
        private SymbolicDifferentiationRule(SimpleTensor variable) {
            super(variable);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            return Complex.ONE;
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            // Stateless with respect to indices, so the same rule can be reused everywhere.
            return this;
        }

        @Override
        int[] getForbidden() {
            return new int[0];
        }
    }

    /**
     * Rule for variables with indices. The derivative template is a symmetrized product of
     * metrics / Kronecker deltas that respects the variable's inner index symmetries.
     */
    public static final class SymmetricDifferentiationRule extends SimpleTensorDifferentiationRule {
        /** Derivative template, expressed in placeholder indices plus {@link #freeVarIndices}. */
        private final Tensor derivative;
        /** Placeholder indices of the template followed by {@link #freeVarIndices}; mapping source. */
        private final int[] allFreeFrom;
        /** Inverted free indices of the variable; they are mapped onto themselves. */
        private final int[] freeVarIndices;

        private SymmetricDifferentiationRule(SimpleTensor variable, Tensor derivative,
                                             int[] allFreeFrom, int[] freeVarIndices) {
            super(variable);
            this.derivative = derivative;
            this.allFreeFrom = allFreeFrom;
            this.freeVarIndices = freeVarIndices;
        }

        SymmetricDifferentiationRule(SimpleTensor variable) {
            super(variable);
            SimpleIndices varIndices = variable.getIndices();
            int count = varIndices.size();
            // Two sets of generated indices, one pair per variable index. The generator is seeded
            // with the variable's indices (presumably so the new ones are fresh w.r.t. them).
            // 'from' carries the inverted raw state of each index, 'to' the original raw state.
            int[] from = new int[count];
            int[] to = new int[count];
            IndexGenerator generator = new IndexGenerator(varIndices);
            for (int i = 0; i < count; i++) {
                int index = varIndices.get(i);
                byte type = IndicesUtils.getType(index);
                int rawState = IndicesUtils.getRawStateInt(index);
                from[i] = IndicesUtils.setRawState(generator.generate(type),
                        IndicesUtils.inverseIndexState(rawState));
                to[i] = IndicesUtils.setRawState(generator.generate(type), rawState);
            }
            int[] fromAndTo = ArraysUtils.addAll(from, to);
            // Temporary placeholder tensor; its odd name presumably avoids clashes with user tensors.
            SimpleTensor placeholder = Tensors.simpleTensor("@!@#@##_AS@23@@#",
                    IndicesFactory.createSimple((IndicesSymmetries) null, fromAndTo));
            Tensor symmetrized = SymmetrizeSimpleTensorTransformation.symmetrize(
                    placeholder, from, varIndices.getSymmetries().getInnerSymmetries());
            Tensor mapped = ApplyIndexMapping.applyIndexMapping(symmetrized,
                    new Mapping(fromAndTo,
                            ArraysUtils.addAll(varIndices.getInverted().getAllIndices().copy(), to)),
                    new int[0]);
            ProductBuilder deltas = new ProductBuilder(0, count);
            for (int i = 0; i < count; i++) {
                deltas.put(Tensors.createMetricOrKronecker(to[i], from[i]));
            }
            this.derivative = new SubstitutionTransformation(placeholder, deltas.build()).transform(mapped);
            this.freeVarIndices = varIndices.getFree().getInverted().getAllIndices().copy();
            this.allFreeFrom = ArraysUtils.addAll(to, this.freeVarIndices);
        }

        @Override
        Tensor differentiateSimpleTensorWithoutCheck(SimpleTensor simpleTensor) {
            // Placeholder indices -> indices of the differentiated tensor; var indices -> themselves.
            int[] target = ArraysUtils.addAll(simpleTensor.getIndices().getAllIndices().copy(),
                    this.freeVarIndices);
            return ApplyIndexMapping.applyIndexMapping(this.derivative,
                    new Mapping(this.allFreeFrom, target), new int[0]);
        }

        @Override
        SimpleTensorDifferentiationRule newRuleForTensor(Tensor tensor) {
            // Rename the template's dummies away from every index name used in 'tensor'.
            Tensor renamed = ApplyIndexMapping.renameDummy(this.derivative,
                    TensorUtils.getAllIndicesNamesT(tensor).toArray());
            return new SymmetricDifferentiationRule(this.var, renamed, this.allFreeFrom, this.freeVarIndices);
        }

        @Override
        int[] getForbidden() {
            return TensorUtils.getAllIndicesNamesT(this.derivative).toArray();
        }
    }

    /**
     * Creates a transformation that differentiates successively by each of {@code vars}.
     *
     * @param vars differentiation variables (copied defensively)
     */
    public DifferentiateTransformation(SimpleTensor... vars) {
        this.vars = vars.clone();
        this.expandAndContract = NO_TRANSFORMATIONS;
    }

    /**
     * Creates a transformation that differentiates successively by each of {@code vars}, applying
     * {@code expandAndContract} to intermediate results.
     *
     * @param expandAndContract transformations applied during differentiation (copied defensively)
     * @param vars              differentiation variables (copied defensively)
     */
    public DifferentiateTransformation(Transformation[] expandAndContract, SimpleTensor... vars) {
        this.vars = vars.clone();
        this.expandAndContract = expandAndContract.clone();
    }

    @Override
    public Tensor transform(Tensor tensor) {
        return differentiate(tensor, this.expandAndContract, this.vars);
    }

    /**
     * Differentiates {@code tensor} {@code order} times with respect to {@code variable}.
     *
     * @param tensor   expression to differentiate
     * @param variable differentiation variable
     * @param order    number of differentiations; values {@code <= 0} return {@code tensor} unchanged
     * @return the derivative
     * @throws IllegalArgumentException if {@code variable} has indices and {@code order > 1}
     */
    public static Tensor differentiate(Tensor tensor, SimpleTensor variable, int order) {
        if (variable.getIndices().size() != 0 && order > 1) {
            throw new IllegalArgumentException(
                    "Higher-order derivatives are supported only for index-free variables; got variable "
                            + variable + " with order " + order);
        }
        for (; order > 0; order--) {
            tensor = differentiate(tensor, NO_TRANSFORMATIONS, variable);
        }
        return tensor;
    }

    /**
     * Differentiates {@code tensor} successively with respect to each of {@code vars}, without
     * additional transformations.
     *
     * @return the derivative, or {@code tensor} itself when {@code vars} is empty
     */
    public static Tensor differentiate(Tensor tensor, SimpleTensor... vars) {
        // The Transformation[] overload already handles the 0- and 1-variable cases.
        return differentiate(tensor, NO_TRANSFORMATIONS, vars);
    }

    /**
     * Differentiates {@code tensor} successively with respect to each of {@code vars}, applying
     * {@code expandAndContract} to intermediate results.
     *
     * <p>When at least one variable carries indices, the dummy indices of the variables and of
     * {@code tensor} are first renamed (via {@code ApplyIndexMapping.renameDummy}) away from the
     * collected set of index names before differentiating.
     *
     * @return the derivative, or {@code tensor} itself when {@code vars} is empty
     */
    public static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor... vars) {
        if (vars.length == 0) {
            return tensor;
        }
        if (vars.length == 1) {
            return differentiate(tensor, expandAndContract, vars[0]);
        }
        boolean hasIndexedVar = false;
        for (SimpleTensor v : vars) {
            if (v.getIndices().size() != 0) {
                hasIndexedVar = true;
                break;
            }
        }
        SimpleTensor[] resolvedVars = vars;
        if (hasIndexedVar) {
            // Every index name in use: those of the tensor plus the free ones of all variables.
            TIntHashSet forbidden = TensorUtils.getAllIndicesNamesT(tensor);
            for (SimpleTensor v : vars) {
                forbidden.addAll(IndicesUtils.getIndicesNames(v.getIndices().getFree()));
            }
            resolvedVars = vars.clone();
            for (int i = 0; i < resolvedVars.length; i++) {
                SimpleIndices indices = resolvedVars[i].getIndices();
                if (!forbidden.isEmpty() && indices.size() != 0) {
                    // Only variables with contracted (dummy) indices need renaming.
                    if (indices.size() != indices.getFree().size()) {
                        resolvedVars[i] = (SimpleTensor) ApplyIndexMapping.renameDummy(
                                resolvedVars[i], forbidden.toArray());
                    }
                    forbidden.addAll(IndicesUtils.getIndicesNames(resolvedVars[i].getIndices()));
                }
            }
            tensor = ApplyIndexMapping.renameIndicesOfFieldsArguments(
                    ApplyIndexMapping.renameDummy(tensor,
                            TensorUtils.getAllIndicesNamesT(resolvedVars).toArray(), forbidden),
                    forbidden);
        }
        for (SimpleTensor v : resolvedVars) {
            tensor = differentiate1(tensor, createRule(v), expandAndContract);
        }
        return tensor;
    }

    /** Single-variable differentiation, renaming dummies first when the variable has indices. */
    private static Tensor differentiate(Tensor tensor, Transformation[] expandAndContract, SimpleTensor variable) {
        if (variable.getIndices().size() != 0) {
            TIntHashSet forbidden = TensorUtils.getAllIndicesNamesT(tensor);
            // 'forbidden' is still exactly the tensor's index names here, so reuse it.
            variable = (SimpleTensor) ApplyIndexMapping.renameDummy(variable, forbidden.toArray());
            forbidden.addAll(IndicesUtils.getIndicesNames(variable.getIndices()));
            tensor = ApplyIndexMapping.renameIndicesOfFieldsArguments(
                    ApplyIndexMapping.renameDummy(tensor,
                            TensorUtils.getAllIndicesNamesT(variable).toArray(), forbidden),
                    forbidden);
        }
        return differentiate1(tensor, createRule(variable), expandAndContract);
    }

    /** Differentiates {@code tensor} with a rule adapted to it, after renaming its clashing dummies. */
    private static Tensor differentiateWithRenaming(Tensor tensor, SimpleTensorDifferentiationRule rule,
                                                    Transformation[] expandAndContract) {
        SimpleTensorDifferentiationRule adapted = rule.newRuleForTensor(tensor);
        return differentiate1(ApplyIndexMapping.renameDummy(tensor, adapted.getForbidden()),
                adapted, expandAndContract);
    }

    /**
     * Recursive core: differentiates {@code tensor} structurally using {@code rule}.
     *
     * @throws UnsupportedOperationException for tensor kinds not handled below
     */
    private static Tensor differentiate1(Tensor tensor, SimpleTensorDifferentiationRule rule,
                                         Transformation[] expandAndContract) {
        // Exact class match: TensorField (a SimpleTensor subclass) is handled separately.
        if (tensor.getClass() == SimpleTensor.class) {
            return applyTransformations(rule.differentiateSimpleTensor((SimpleTensor) tensor), expandAndContract);
        }
        if (tensor.getClass() == TensorField.class) {
            // Chain rule over the field's arguments.
            TensorField field = (TensorField) tensor;
            SumBuilder terms = new SumBuilder(tensor.size());
            for (int i = tensor.size() - 1; i >= 0; i--) {
                Tensor argDerivative = differentiate1(field.get(i), rule, expandAndContract);
                if (!TensorUtils.isZero(argDerivative)) {
                    terms.put(Tensors.multiply(argDerivative,
                            Tensors.fieldDerivative(field, field.getArgIndices(i).getInverted(), i)));
                }
            }
            return applyTransformations(EliminateMetricsTransformation.eliminate(terms.build()), expandAndContract);
        }
        if (tensor instanceof Sum) {
            SumBuilder terms = new SumBuilder();
            for (Iterator<Tensor> it = tensor.iterator(); it.hasNext(); ) {
                terms.put(applyTransformations(differentiate1(it.next(), rule, expandAndContract), expandAndContract));
            }
            return terms.build();
        }
        if (tensor instanceof ScalarFunction) {
            // Chain rule: f'(arg) * d(arg).
            return applyTransformations(Tensors.multiply(((ScalarFunction) tensor).derivative(),
                    differentiateWithRenaming(tensor.get(0), rule, expandAndContract)), expandAndContract);
        }
        if (tensor instanceof Power) {
            Tensor base = tensor.get(0);
            Tensor exponent = tensor.get(1);
            // d(b^e) = e * b^(e - 1) * db + b^e * log(b) * de
            Tensor baseTerm = Tensors.multiply(exponent,
                    Tensors.pow(base, Tensors.sum(exponent, Complex.MINUS_ONE)),
                    differentiate1(base, rule, expandAndContract));
            Tensor exponentTerm = Tensors.multiply(tensor, Tensors.log(base),
                    differentiateWithRenaming(exponent, rule, expandAndContract));
            return applyTransformations(Tensors.sum(baseTerm, exponentTerm), expandAndContract);
        }
        if (tensor instanceof Product) {
            // Product rule: replace each factor in turn by its derivative.
            boolean indexedVar = rule.var.getIndices().size() != 0;
            SumBuilder terms = new SumBuilder();
            for (int i = tensor.size() - 1; i >= 0; i--) {
                Tensor term = tensor.set(i, differentiate1(tensor.get(i), rule, expandAndContract));
                if (indexedVar) {
                    term = EliminateMetricsTransformation.eliminate(term);
                }
                terms.put(applyTransformations(term, expandAndContract));
            }
            return terms.build();
        }
        if (tensor instanceof Complex) {
            return Complex.ZERO;
        }
        throw new UnsupportedOperationException(
                "Differentiation is not supported for tensors of type " + tensor.getClass().getName());
    }

    /** Applies {@code transformations} to {@code tensor} in order. */
    private static Tensor applyTransformations(Tensor tensor, Transformation[] transformations) {
        for (Transformation transformation : transformations) {
            tensor = transformation.transform(tensor);
        }
        return tensor;
    }

    /** Chooses the symbolic rule for index-free variables and the symmetric rule otherwise. */
    private static SimpleTensorDifferentiationRule createRule(SimpleTensor variable) {
        return variable.getIndices().size() == 0
                ? new SymbolicDifferentiationRule(variable)
                : new SymmetricDifferentiationRule(variable);
    }
}
