package cc.redberry.transformation.integral;

import cc.redberry.core.indexmapping.IndexMappingDirect;
import cc.redberry.core.indices.IndicesTypeStructure;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.tensor.Integral;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.transformations.ApplyIndexMappingDirectTransformation;
import cc.redberry.transformation.Transformation;
import cc.redberry.transformation.substitutions.SubstitutionsFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/* loaded from: input_file:cc/redberry/transformation/integral/CollectIntegralFromSum.class */
/**
 * Merges the {@link Integral} summands of a {@link Sum}.
 *
 * <p>Two integral summands are merged when their {@code vars()} arrays have the same length and
 * can be paired (after ordering) so that every pair has a similar index type structure. In that
 * case the second integral's target is rewritten through substitutions that map its variables
 * onto the first integral's variables, and the result is added to the first integral's target.
 * Tensors that are not a {@link Sum} are returned unchanged.
 */
public class CollectIntegralFromSum implements Transformation {
    /** Shared instance; the transformation keeps no per-call state in fields. */
    public static final CollectIntegralFromSum INSTANCE = new CollectIntegralFromSum();

    /** Accumulates integrals, folding each new one into the first stored integral it matches. */
    private static class IntegralIP {
        /** Integrals collected so far; a matching later integral is merged into one of these. */
        final List<Integral> integrals = new ArrayList<>();

        /**
         * Orders variables by number of indices, then by the hash code of their
         * {@link IndicesTypeStructure}.
         *
         * <p>It is used only to line up the variables of two integrals. If two distinct
         * structures collide in hash, the pairing may be misaligned. The subsequent
         * {@code similarTypeStructure} check then rejects the match, so a collision can cause
         * a missed merge but not a wrong one.
         */
        static final class IndicesComparator implements Comparator<SimpleTensor> {
            static final IndicesComparator INSTANCE = new IndicesComparator();

            private IndicesComparator() {
            }

            @Override
            public int compare(SimpleTensor left, SimpleTensor right) {
                SimpleIndices leftIndices = left.getIndices();
                SimpleIndices rightIndices = right.getIndices();
                int bySize = Integer.compare(leftIndices.size(), rightIndices.size());
                if (bySize != 0) {
                    return bySize;
                }
                return Integer.compare(
                        new IndicesTypeStructure(leftIndices).hashCode(),
                        new IndicesTypeStructure(rightIndices).hashCode());
            }
        }

        private IntegralIP() {
        }

        /**
         * Adds {@code integral} to the collection.
         *
         * <p>If some already-stored integral matches it (see {@link #getSubs}), the
         * substitutions are applied to {@code integral}'s target. The stored integral's target
         * then becomes {@code Sum(storedTarget, transformedTarget)}, and {@code integral} itself
         * is not stored. Otherwise {@code integral} is appended.
         */
        void put(Integral integral) {
            for (Integral collected : integrals) {
                Transformation[] subs = getSubs(integral.vars(), collected.vars());
                if (subs == null) {
                    continue;
                }
                Tensor target = integral.target();
                for (Transformation sub : subs) {
                    target = sub.transform(target);
                }
                collected.setTarget(new Sum(collected.target(), target));
                return;
            }
            integrals.add(integral);
        }

        /**
         * Builds one substitution per variable pair, mapping each variable of {@code from} onto
         * its counterpart in {@code to}.
         *
         * <p>Pairing is done on sorted copies of both arrays, so the callers' arrays are not
         * reordered. Before each substitution is created, the {@code from} variable's indices
         * are remapped onto the counterpart's indices via
         * {@code ApplyIndexMappingDirectTransformation.perform}. Its result is not used, so this
         * presumably mutates the variable in place (TODO confirm). That remapping happens only
         * after every pair has passed the type-structure check, so a rejected match leaves all
         * variables untouched.
         *
         * @param from variables of the integral being added
         * @param to variables of an already-stored integral
         * @return the substitutions, or {@code null} if the arrays differ in length or some
         *     pair has dissimilar index type structure
         */
        Transformation[] getSubs(SimpleTensor[] from, SimpleTensor[] to) {
            int length = from.length;
            if (length != to.length) {
                return null;
            }
            // Sort copies: the arrays may be the integrals' own storage.
            SimpleTensor[] sortedFrom = from.clone();
            SimpleTensor[] sortedTo = to.clone();
            Arrays.sort(sortedFrom, IndicesComparator.INSTANCE);
            Arrays.sort(sortedTo, IndicesComparator.INSTANCE);

            // Check every pair before touching anything: remapping indices is a side effect and
            // must not be applied partially when a later pair turns out not to match.
            for (int i = 0; i < length; i++) {
                if (!sortedFrom[i].getIndices().similarTypeStructure(sortedTo[i].getIndices())) {
                    return null;
                }
            }

            Transformation[] subs = new Transformation[length];
            for (int i = 0; i < length; i++) {
                SimpleIndices fromIndices = sortedFrom[i].getIndices();
                SimpleIndices toIndices = sortedTo[i].getIndices();
                ApplyIndexMappingDirectTransformation.INSTANCE.perform(
                        sortedFrom[i], new IndexMappingDirect(fromIndices, toIndices));
                subs[i] = SubstitutionsFactory.createSubstitution(sortedFrom[i], sortedTo[i]);
            }
            return subs;
        }

        /** Returns a new {@link Sum} of all collected (possibly merged) integrals. */
        Tensor result() {
            return new Sum(integrals);
        }
    }

    private CollectIntegralFromSum() {
    }

    /**
     * Collects the integral summands of {@code tensor} when it is a {@link Sum}.
     *
     * <p>The input sum is modified in place. Every {@link Integral} summand is removed through
     * the sum's iterator, and the merged integrals are added back as a single {@link Sum} term.
     * The result of {@code tensor.equivalent()} is returned.
     *
     * @param tensor the expression to transform
     * @return {@code tensor} unchanged if it is not a {@link Sum}; otherwise
     *     {@code tensor.equivalent()} after collection
     */
    @Override
    public Tensor transform(Tensor tensor) {
        if (!(tensor instanceof Sum)) {
            return tensor;
        }
        IntegralIP collector = new IntegralIP();
        TensorIterator iterator = tensor.iterator();
        while (iterator.hasNext()) {
            Tensor summand = iterator.next();
            if (summand instanceof Integral) {
                collector.put((Integral) summand);
                iterator.remove();
            }
        }
        // No integral summands were found, so nothing was removed.
        // Do not append an empty Sum term in that case.
        if (!collector.integrals.isEmpty()) {
            ((Sum) tensor).add(collector.result());
        }
        return tensor.equivalent();
    }
}
