package cc.redberry.core.transformations.substitutions;

import cc.redberry.core.context.NameDescriptorForTensorField;
import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.IndexMappings;
import cc.redberry.core.indexmapping.Mapping;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.indices.UnsafeIndicesFactory;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.DifferentiateTransformation;
import cc.redberry.core.utils.IntArray;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Primitive substitution whose left-hand side ({@code from}) is a tensor field.
 *
 * <p>The substitution applies to the field itself and to its derivatives. A field node with the
 * same parent descriptor as {@code from} matches when its derivative order is not lower in any
 * argument slot. In that case a matching (from, to) pair is built by differentiating both sides
 * of the original substitution with respect to the corresponding {@code from} arguments. The
 * pair is cached by derivative orders and then used for the replacement.
 *
 * <p>The derivative cache is a plain {@link HashMap} that {@link #newTo_} mutates without
 * synchronization.
 */
public class PrimitiveTensorFieldSubstitution extends PrimitiveSubstitution {
    /** Name descriptor of the {@code from} field. */
    private final NameDescriptorForTensorField fromDescriptor;
    /** Derivative orders (one per argument) of the {@code from} field. */
    private final IntArray orders;
    /** Cache: derivative orders -> (lhs, rhs) pair; seeded with the original pair. */
    private final Map<IntArray, DFromTo> derivatives = new HashMap<>();

    /** Left- and right-hand side of the substitution for one particular set of derivative orders. */
    public static class DFromTo {
        final TensorField from;
        final Tensor to;

        private DFromTo(TensorField from, Tensor to) {
            this.from = from;
            this.to = to;
        }
    }

    /**
     * Creates a field substitution {@code from -> to}.
     *
     * @param from left-hand side; must be a {@link TensorField} (it is cast unconditionally)
     * @param to   right-hand side
     */
    public PrimitiveTensorFieldSubstitution(Tensor from, Tensor to) {
        super(from, to);
        TensorField fromField = (TensorField) from;
        this.fromDescriptor = fromField.getNameDescriptor();
        this.orders = new IntArray(fromDescriptor.getDerivativeOrders());
        this.derivatives.put(orders, new DFromTo(fromField, to));
    }

    /**
     * Returns the replacement for {@code currentNode}, or {@code currentNode} itself if it does
     * not match.
     *
     * <p>A node matches when it has the same parent descriptor as {@code from}, is differentiated
     * at least as many times as {@code from} in every argument slot, and admits an index mapping
     * from the corresponding derivative of {@code from}.
     *
     * @param currentNode node being visited; cast to {@link TensorField} unconditionally
     * @param iterator    supplies the indices forbidden at the current position
     * @return the substituted tensor, or {@code currentNode} when nothing applies
     */
    @Override
    Tensor newTo_(Tensor currentNode, SubstitutionIterator iterator) {
        TensorField currentField = (TensorField) currentNode;
        NameDescriptorForTensorField currentDescriptor = currentField.getNameDescriptor();

        // Only the same field (possibly differentiated) can match.
        if (currentDescriptor.getParent().getId() != fromDescriptor.getParent().getId()) {
            return currentNode;
        }
        // A lower derivative order in any slot cannot be produced by differentiating `from`.
        for (int slot = currentNode.size() - 1; slot >= 0; --slot) {
            if (currentDescriptor.getDerivativeOrder(slot) < fromDescriptor.getDerivativeOrder(slot)) {
                return currentNode;
            }
        }

        IntArray currentOrders = new IntArray(currentDescriptor.getDerivativeOrders());
        DFromTo fromTo = derivatives.get(currentOrders);
        if (fromTo == null) {
            fromTo = differentiate(currentOrders, iterator);
            derivatives.put(currentOrders, fromTo);
        }
        return applyFromTo(fromTo, currentField, currentNode, iterator);
    }

    /**
     * Builds the (lhs, rhs) pair for the given derivative orders by differentiating the original
     * {@code from}/{@code to} once per extra derivative in each argument slot.
     */
    private DFromTo differentiate(IntArray targetOrders, SubstitutionIterator iterator) {
        TensorField newFrom = (TensorField) from;
        Tensor newTo = to;
        // Created lazily: only needed when some differentiated argument carries indices.
        IndexGenerator generator = null;
        for (int arg = targetOrders.length() - 1; arg >= 0; --arg) {
            int extra = targetOrders.get(arg) - orders.get(arg);
            if (extra <= 0) {
                continue;
            }
            // The argument variable of `from` is the same for every extra derivative in this slot.
            SimpleTensor var = (SimpleTensor) from.get(arg);
            SimpleIndices varIndices = var.getIndices();
            for (; extra > 0; --extra) {
                int[] freshIndices = new int[varIndices.size()];
                if (freshIndices.length != 0 && generator == null) {
                    generator = createIndexGenerator(iterator);
                }
                // Descending order is kept on purpose: generated names depend on call order.
                for (int j = freshIndices.length - 1; j >= 0; --j) {
                    int index = varIndices.get(j);
                    freshIndices[j] = IndicesUtils.setRawState(
                            IndicesUtils.getRawStateInt(index),
                            generator.generate(IndicesUtils.getType(index)));
                }
                SimpleIndices derivativeIndices =
                        UnsafeIndicesFactory.createIsolatedUnsafeWithoutSort(null, freshIndices);
                SimpleTensor freshVar = Tensors.setIndices(var, derivativeIndices);
                newFrom = Tensors.fieldDerivative(newFrom, derivativeIndices.getInverted(), arg);
                newTo = new DifferentiateTransformation(freshVar).transform(newTo);
            }
        }
        return new DFromTo(newFrom, newTo);
    }

    /**
     * Creates an index generator that avoids the indices forbidden at the current position and
     * all index names occurring in {@code from} and {@code to}.
     */
    private IndexGenerator createIndexGenerator(SubstitutionIterator iterator) {
        TIntHashSet forbidden = new TIntHashSet(iterator.getForbidden());
        forbidden.addAll(TensorUtils.getAllIndicesNamesT(from));
        forbidden.addAll(TensorUtils.getAllIndicesNamesT(to));
        return new IndexGenerator(forbidden.toArray());
    }

    /**
     * Applies a cached (lhs, rhs) pair to the current field node.
     *
     * <p>Arguments of the current field that are not positively mappable from the lhs arguments
     * are substituted into the rhs. The lhs argument is first relabeled to the current argument
     * indices. The result then receives the lhs-to-current index mapping: the full mapping when
     * it has indices, or only the mapping's sign when it is symbolic.
     *
     * @return the substituted tensor, or {@code currentNode} when no index mapping exists
     */
    private Tensor applyFromTo(DFromTo fromTo, TensorField currentField, Tensor currentNode,
                               SubstitutionIterator iterator) {
        TensorField lhs = fromTo.from;
        Mapping mapping = IndexMappings.simpleTensorsPort(lhs, currentField).take();
        if (mapping == null) {
            return currentNode;
        }

        SimpleIndices[] lhsArgIndices = lhs.getArgIndices();
        SimpleIndices[] currentArgIndices = currentField.getArgIndices();
        List<Tensor> argFrom = new ArrayList<>();
        List<Tensor> argTo = new ArrayList<>();
        for (int i = lhs.size() - 1; i >= 0; --i) {
            if (IndexMappings.positiveMappingExists(currentNode.get(i), lhs.get(i))) {
                continue; // argument already matches; nothing to substitute
            }
            int[] lhsIndices = lhsArgIndices[i].getAllIndices().copy();
            int[] currentIndices = currentArgIndices[i].getAllIndices().copy();
            assert currentIndices.length == lhsIndices.length;
            argFrom.add(ApplyIndexMapping.applyIndexMapping(
                    lhs.get(i), new Mapping(lhsIndices, currentIndices), new int[0]));
            argTo.add(currentNode.get(i));
        }

        Tensor result = new SubstitutionTransformation(
                argFrom.toArray(new Tensor[0]),
                argTo.toArray(new Tensor[0]),
                false).transform(fromTo.to);
        if (!TensorUtils.isSymbolic(result)) {
            result = ApplyIndexMapping.applyIndexMapping(result, mapping, iterator.getForbidden());
        } else if (mapping.getSign()) {
            result = Tensors.negate(result);
        }
        return result;
    }
}
