package cc.redberry.core.transformations.substitutions;

import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.Expression;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.utils.TensorUtils;
import java.util.Iterator;

/* loaded from: input_file:cc/redberry/core/transformations/substitutions/SubstitutionTransformation.class */
/**
 * Transformation that replaces occurrences of one or more "from" tensors by the corresponding
 * "to" tensors.
 *
 * <p>Each {@code from -> to} pair is compiled into a {@link PrimitiveSubstitution}. The concrete
 * kind of primitive substitution depends on the runtime class of {@code from} (simple tensor,
 * tensor field, product, sum, or anything else).
 *
 * <p>The {@code applyIfModified} flag controls how {@link #transform(Tensor)} behaves:
 * <ul>
 *   <li>{@code true}: sub-tensors already modified during the traversal are visited as well, and
 *       every substitution is applied in sequence to each visited sub-tensor.</li>
 *   <li>{@code false}: already-modified sub-tensors are skipped, and at most the first
 *       substitution that changes a sub-tensor is applied to it.</li>
 * </ul>
 */
public final class SubstitutionTransformation implements Transformation {
    /** Compiled substitutions, tried in array order on every visited sub-tensor. */
    private final PrimitiveSubstitution[] primitiveSubstitutions;
    /** See the class documentation. */
    private final boolean applyIfModified;

    /**
     * Wraps already compiled substitutions. The array is stored without copying, so callers
     * must pass an array that is not shared with any other instance.
     */
    private SubstitutionTransformation(PrimitiveSubstitution[] primitiveSubstitutions, boolean applyIfModified) {
        this.primitiveSubstitutions = primitiveSubstitutions;
        this.applyIfModified = applyIfModified;
    }

    /**
     * Creates a substitution from expressions.
     *
     * <p>For each expression, {@code expression.get(0)} is the tensor to replace and
     * {@code expression.get(1)} is its replacement. No free-index consistency check is performed
     * here, unlike the {@link Tensor}-based constructors.
     * NOTE(review): presumably {@link Expression} already enforces consistent free indices; confirm.
     *
     * @param expressions     substitution rules
     * @param applyIfModified see the class documentation
     */
    public SubstitutionTransformation(Expression[] expressions, boolean applyIfModified) {
        this.primitiveSubstitutions = new PrimitiveSubstitution[expressions.length];
        for (int i = 0; i < expressions.length; ++i) {
            this.primitiveSubstitutions[i] = createPrimitiveSubstitution(expressions[i].get(0), expressions[i].get(1));
        }
        this.applyIfModified = applyIfModified;
    }

    /**
     * Creates a single substitution {@code expression.get(0) -> expression.get(1)}.
     *
     * <p>{@code applyIfModified} is derived as in {@link #SubstitutionTransformation(Tensor, Tensor)}.
     *
     * @param expression substitution rule
     * @throws IllegalArgumentException if the free indices of both sides differ (unless the
     *                                  right-hand side is zero or indeterminate)
     */
    public SubstitutionTransformation(Expression expression) {
        this(expression.get(0), expression.get(1));
    }

    /**
     * Creates a substitution from expressions.
     *
     * <p>{@code applyIfModified} is {@code !TensorUtils.shareSimpleTensors(lhs, rhs)} when exactly
     * one expression is given, and {@code false} otherwise (including when none are given).
     *
     * @param expressions substitution rules
     */
    public SubstitutionTransformation(Expression... expressions) {
        this(expressions, expressions.length == 1
                && !TensorUtils.shareSimpleTensors(expressions[0].get(0), expressions[0].get(1)));
    }

    /**
     * Creates a single substitution {@code from -> to}.
     *
     * @param from            tensor to replace
     * @param to              replacement
     * @param applyIfModified see the class documentation
     * @throws IllegalArgumentException if the free indices of {@code from} and {@code to} differ
     *                                  (unless {@code to} is zero or indeterminate)
     */
    public SubstitutionTransformation(Tensor from, Tensor to, boolean applyIfModified) {
        checkConsistence(from, to);
        this.primitiveSubstitutions = new PrimitiveSubstitution[]{createPrimitiveSubstitution(from, to)};
        this.applyIfModified = applyIfModified;
    }

    /**
     * Creates substitutions {@code from[i] -> to[i]}.
     *
     * <p>{@code applyIfModified} is {@code !TensorUtils.shareSimpleTensors(from[0], to[0])} when
     * both arrays hold exactly one element, and {@code false} otherwise.
     *
     * @param from tensors to replace
     * @param to   replacements, positionally matched with {@code from}
     * @throws IllegalArgumentException if the arrays differ in length, or if a pair has
     *                                  inconsistent free indices
     */
    public SubstitutionTransformation(Tensor[] from, Tensor[] to) {
        // Check both lengths before indexing. Otherwise a length mismatch surfaces as an
        // ArrayIndexOutOfBoundsException here instead of the IllegalArgumentException
        // thrown by checkConsistence.
        this(from, to, from.length == 1 && to.length == 1 && !TensorUtils.shareSimpleTensors(from[0], to[0]));
    }

    /**
     * Creates a single substitution {@code from -> to}.
     *
     * <p>{@code applyIfModified} is {@code !TensorUtils.shareSimpleTensors(from, to)}.
     *
     * @param from tensor to replace
     * @param to   replacement
     * @throws IllegalArgumentException if the free indices of {@code from} and {@code to} differ
     *                                  (unless {@code to} is zero or indeterminate)
     */
    public SubstitutionTransformation(Tensor from, Tensor to) {
        this(from, to, !TensorUtils.shareSimpleTensors(from, to));
    }

    /**
     * Creates substitutions {@code from[i] -> to[i]}.
     *
     * @param from            tensors to replace
     * @param to              replacements, positionally matched with {@code from}
     * @param applyIfModified see the class documentation
     * @throws IllegalArgumentException if the arrays differ in length, or if a pair has
     *                                  inconsistent free indices
     */
    public SubstitutionTransformation(Tensor[] from, Tensor[] to, boolean applyIfModified) {
        checkConsistence(from, to);
        this.primitiveSubstitutions = new PrimitiveSubstitution[from.length];
        for (int i = 0; i < from.length; ++i) {
            this.primitiveSubstitutions[i] = createPrimitiveSubstitution(from[i], to[i]);
        }
        this.applyIfModified = applyIfModified;
    }

    /**
     * Returns a copy of this transformation in which every pair is handled by a
     * {@link PrimitiveSimpleTensorSubstitution}, whatever the type of its "from" tensor.
     *
     * <p>The {@code applyIfModified} flag is preserved.
     *
     * @return a new transformation; this instance is left unchanged
     */
    public SubstitutionTransformation asSimpleSubstitution() {
        PrimitiveSubstitution[] simple = new PrimitiveSubstitution[primitiveSubstitutions.length];
        for (int i = 0; i < simple.length; ++i) {
            simple[i] = new PrimitiveSimpleTensorSubstitution(primitiveSubstitutions[i].from, primitiveSubstitutions[i].to);
        }
        return new SubstitutionTransformation(simple, applyIfModified);
    }

    /**
     * Validates that {@code from} and {@code to} have equal length and that every pair has
     * consistent free indices.
     *
     * @throws IllegalArgumentException on the first violation found
     */
    private static void checkConsistence(Tensor[] from, Tensor[] to) {
        if (from.length != to.length) {
            throw new IllegalArgumentException("from array and to array have different length.");
        }
        for (int i = 0; i < from.length; ++i) {
            checkConsistence(from[i], to[i]);
        }
    }

    /**
     * Validates that {@code from} and {@code to} have the same free indices (order ignored).
     *
     * <p>The check is skipped when {@code to} is zero or indeterminate.
     *
     * @throws IllegalArgumentException if the free indices differ
     */
    private static void checkConsistence(Tensor from, Tensor to) {
        if (TensorUtils.isZeroOrIndeterminate(to)) {
            return;
        }
        if (!from.getIndices().getFree().equalsRegardlessOrder(to.getIndices().getFree())) {
            throw new IllegalArgumentException("Tensor from free indices not equal to tensor to free indices: " + from.getIndices().getFree() + "  " + to.getIndices().getFree());
        }
    }

    /**
     * Chooses the primitive substitution implementation for {@code from -> to}, based on the
     * exact runtime class of {@code from}.
     */
    private static PrimitiveSubstitution createPrimitiveSubstitution(Tensor from, Tensor to) {
        Class<?> type = from.getClass();
        if (type == SimpleTensor.class) {
            return new PrimitiveSimpleTensorSubstitution(from, to);
        }
        if (type == TensorField.class) {
            // The field-specific substitution is only used when every argument is a simple tensor.
            return allArgumentsSimple(from)
                    ? new PrimitiveTensorFieldSubstitution(from, to)
                    : new PrimitiveSimpleTensorSubstitution(from, to);
        }
        if (type == Product.class) {
            // A two-factor product whose first factor is a number, "c * x -> to", is rewritten
            // as "x -> to / c" and compiled for x itself.
            if (from.size() == 2 && from.get(0) instanceof Complex) {
                return createPrimitiveSubstitution(from.get(1), Tensors.divide(to, from.get(0)));
            }
            return new PrimitiveProductSubstitution(from, to);
        }
        if (type == Sum.class) {
            return new PrimitiveSumSubstitution(from, to);
        }
        return new PrimitiveSimpleTensorSubstitution(from, to);
    }

    /** Returns {@code true} if every element of {@code tensor} is a {@link SimpleTensor}. */
    private static boolean allArgumentsSimple(Tensor tensor) {
        for (int i = 0; i < tensor.size(); ++i) {
            if (!(tensor.get(i) instanceof SimpleTensor)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Applies the substitutions to every sub-tensor yielded by a {@link SubstitutionIterator}
     * over {@code t}.
     *
     * @param t tensor to transform
     * @return the result assembled by the iterator
     */
    @Override // cc.redberry.core.transformations.Transformation
    public Tensor transform(Tensor t) {
        SubstitutionIterator iterator = new SubstitutionIterator(t);
        Tensor current;
        while ((current = iterator.next()) != null) {
            if (!applyIfModified && iterator.isCurrentModified()) {
                continue;
            }
            Tensor previous = current;
            for (PrimitiveSubstitution substitution : primitiveSubstitutions) {
                current = substitution.newTo(previous, iterator);
                // Without applyIfModified, stop at the first substitution that changed something
                // (checked by reference identity, as newTo returns its input when nothing matches).
                if (current != previous && !applyIfModified) {
                    break;
                }
                previous = current;
            }
            iterator.set(current);
        }
        return iterator.result();
    }

    /**
     * Returns the rules as {@code {from1 -> to1,from2 -> to2}}.
     *
     * <p>An empty transformation yields {@code {}}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append('{');
        for (int i = 0; i < primitiveSubstitutions.length; ++i) {
            if (i != 0) {
                sb.append(',');
            }
            PrimitiveSubstitution substitution = primitiveSubstitutions[i];
            sb.append(substitution.from).append(" -> ").append(substitution.to);
        }
        return sb.append('}').toString();
    }
}
