package cc.redberry.core.transformations.expand;

import cc.redberry.concurrent.OutputPort;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.FastTensors;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.transformations.Transformation;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Low-level helpers used by the expand transformations: multiplying out products of
 * sums, distributing factors over the terms of a sum and expanding integer powers of
 * sums.
 */
public final class ExpandUtils {

    /**
     * Multiplies out the sums that occur among the factors of the indexless sub-product
     * of a {@link Product}. The data sub-product is multiplied back unchanged. Any tensor
     * that is not a {@code Product}, or whose indexless sub-product is not a
     * {@code Product} containing a {@code Sum}, is returned as is.
     */
    public static final Transformation expandIndexlessSubproduct = new Transformation() {
        @Override
        public Tensor transform(Tensor tensor) {
            if (!(tensor instanceof Product)) {
                return tensor;
            }
            Product product = (Product) tensor;
            Tensor indexless = product.getIndexlessSubProduct();
            if (!(indexless instanceof Product) || !containsSum(indexless)) {
                return tensor;
            }
            return Tensors.multiply(
                    expandProductOfSums1(indexless, new Transformation[0], false),
                    product.getDataSubProduct());
        }
    };

    /**
     * An {@link OutputPort} that lazily yields every product
     * {@code factors * sum1.get(i) * sum2.get(j)}, with {@code j} varying fastest. It
     * returns {@code null} once all {@code sum1.size() * sum2.size()} pairs have been
     * produced.
     *
     * <p>Each call to {@link #take()} claims its pair index through
     * {@link AtomicLong#getAndIncrement()}, so no index is handed out twice. Whether the
     * tensor operations performed per pair are themselves thread-safe is not established
     * here.
     */
    public static final class ExpandPairPort implements OutputPort<Tensor> {
        private final Tensor sum1;
        private final Tensor sum2;
        private final Tensor[] factors;
        private final AtomicLong counter = new AtomicLong();

        /** Creates a port over the pairwise products of {@code sum1} and {@code sum2}. */
        public ExpandPairPort(Sum sum1, Sum sum2) {
            this(sum1, sum2, new Tensor[0]);
        }

        /**
         * Creates a port over the pairwise products of {@code sum1} and {@code sum2}, each
         * additionally multiplied by all of {@code factors}.
         *
         * @param sum1    first sum
         * @param sum2    second sum
         * @param factors extra factors multiplied into every produced term (may be empty)
         */
        public ExpandPairPort(Sum sum1, Sum sum2, Tensor[] factors) {
            this.sum1 = sum1;
            this.sum2 = sum2;
            this.factors = factors;
        }

        /**
         * Returns the next pairwise product, or {@code null} when all pairs are exhausted.
         */
        @Override
        public Tensor take() {
            long index = counter.getAndIncrement();
            int width = sum2.size();
            // Widen to long before multiplying: int * int silently overflows for large sums.
            if (index >= (long) sum1.size() * width) {
                return null;
            }
            Tensor left = sum1.get((int) (index / width));
            Tensor right = sum2.get((int) (index % width));
            if (factors.length == 0) {
                return Tensors.multiply(left, right);
            }
            return Tensors.multiply(ArraysUtils.addAll(factors, left, right));
        }
    }

    /**
     * Multiplies out {@code factors * sum1 * sum2} term by term. Each pairwise product is
     * passed through {@code transformations} before being added to the result.
     *
     * @param sum1            first sum
     * @param sum2            second sum
     * @param factors         extra factors multiplied into every term (may be empty)
     * @param transformations transformations applied, in order, to each term
     * @return the result of {@link SumBuilder#build()} over all transformed terms
     */
    public static Tensor expandPairOfSums(Sum sum1, Sum sum2, Tensor[] factors,
                                          Transformation[] transformations) {
        ExpandPairPort port = new ExpandPairPort(sum1, sum2, factors);
        SumBuilder builder = new SumBuilder();
        for (Tensor term = port.take(); term != null; term = port.take()) {
            builder.put(apply(transformations, term));
        }
        return builder.build();
    }

    /**
     * Multiplies out {@code sum1 * sum2} term by term; equivalent to
     * {@link #expandPairOfSums(Sum, Sum, Tensor[], Transformation[])} with no extra factors.
     */
    public static Tensor expandPairOfSums(Sum sum1, Sum sum2, Transformation[] transformations) {
        return expandPairOfSums(sum1, sum2, new Tensor[0], transformations);
    }

    /**
     * Expands the sums contained in {@code product}. The indexless (scalar) and the data
     * (indexed) sub-products are handled separately.
     *
     * <ul>
     *   <li>If neither part contains sums, {@code product} is returned unchanged.</li>
     *   <li>If only the indexless part contains sums, only that part is expanded.</li>
     *   <li>If only the data part contains sums, it is expanded and the scalar part is
     *       distributed over the resulting terms.</li>
     *   <li>A sum in the indexless part whose terms carry indices is moved into the data
     *       expansion. Presumably those are contracted indices, since the sub-product
     *       itself is indexless; TODO confirm.</li>
     * </ul>
     *
     * @param product         product to expand
     * @param transformations transformations applied to the expanded terms
     * @return the expanded tensor
     */
    public static Tensor expandProductOfSums(Product product, Transformation[] transformations) {
        Tensor indexless = product.getIndexlessSubProduct();
        Tensor data = product.getDataSubProduct();

        // True when the indexless part has sums to multiply out.
        boolean indexlessHasSums = false;
        // True when the data part must be expanded.
        boolean expandData = false;
        // True when the indexless part contains a sum whose terms carry indices.
        boolean indexlessHasIndexedSum = false;

        if (sumContainsIndexed(indexless)) {
            indexlessHasIndexedSum = true;
            indexlessHasSums = true;
            expandData = true;
        }
        if (indexless instanceof Product) {
            for (Tensor factor : indexless) {
                if (!(factor instanceof Sum)) {
                    continue;
                }
                indexlessHasSums = true;
                if (sumContainsIndexed(factor)) {
                    indexlessHasIndexedSum = true;
                    expandData = true;
                    break;
                }
            }
        }
        if (!expandData) {
            expandData = data instanceof Sum || (data instanceof Product && containsSum(data));
        }

        if (!expandData && !indexlessHasSums) {
            return product;
        }
        if (!expandData) {
            return Tensors.multiply(expandProductOfSums1(indexless, transformations, false), data);
        }
        if (!indexlessHasSums) {
            Tensor expandedData = expandProductOfSums1(data, transformations, true);
            if (expandedData instanceof Sum) {
                return FastTensors.multiplySumElementsOnScalarFactorAndExpandScalars(
                        (Sum) expandedData, indexless);
            }
            return expandIndexlessSubproduct.transform(Tensors.multiply(indexless, expandedData));
        }

        Tensor scalarPart;
        Tensor indexedPart;
        if (indexlessHasIndexedSum) {
            // Collect every factor that must be expanded together with the data part.
            List<Tensor> indexedFactors = new ArrayList<>();
            if (data instanceof Product) {
                indexedFactors.addAll(Arrays.asList(data.toArray()));
            } else {
                indexedFactors.add(data);
            }
            if (indexless instanceof Sum) {
                indexedFactors.add(indexless);
                scalarPart = Complex.ONE;
            } else {
                assert indexless instanceof Product;
                List<Tensor> scalarFactors = new ArrayList<>(indexless.size());
                boolean scalarHasSums = false;
                for (Tensor factor : indexless) {
                    if (sumContainsIndexed(factor)) {
                        indexedFactors.add(factor);
                    } else {
                        if (factor instanceof Sum) {
                            scalarHasSums = true;
                        }
                        scalarFactors.add(factor);
                    }
                }
                scalarPart = scalarHasSums
                        ? expandProductOfSums1(scalarFactors, transformations, false)
                        : Tensors.multiply(scalarFactors.toArray(new Tensor[0]));
            }
            indexedPart = expandProductOfSums1(indexedFactors, transformations, true);
        } else {
            scalarPart = expandProductOfSums1(indexless, transformations, false);
            indexedPart = expandProductOfSums1(data, transformations, true);
        }

        if (indexedPart instanceof Sum) {
            return FastTensors.multiplySumElementsOnScalarFactorAndExpandScalars(
                    (Sum) indexedPart, scalarPart);
        }
        return Tensors.multiply(scalarPart, indexedPart);
    }

    /**
     * Applies {@code transformations} to {@code tensor} in array order. Each
     * transformation receives the previous one's output.
     *
     * @return the result of the last transformation, or {@code tensor} if the array is empty
     */
    public static Tensor apply(Transformation[] transformations, Tensor tensor) {
        Tensor result = tensor;
        for (Transformation transformation : transformations) {
            result = transformation.transform(result);
        }
        return result;
    }

    /**
     * Multiplies out all sums among {@code factors}.
     *
     * <p>Special cases:
     * <ul>
     *   <li>If {@code factors} is a {@link Tensor} that is not a {@link Product}, it is
     *       returned as is.</li>
     *   <li>If there are no sums, the plain product of the factors is returned.</li>
     * </ul>
     *
     * <p>Otherwise the sums are multiplied pairwise. Whenever an intermediate result is
     * not a {@code Sum}, it joins the plain factors and the next sum starts a new
     * accumulation.
     *
     * @param factors         factors of the product (a {@code Product} or any iterable of tensors)
     * @param transformations transformations applied to the expanded terms
     * @param expandSymbolic  if {@code true}, pairwise products are also passed through
     *                        {@link #expandIndexlessSubproduct}, and a single sum is
     *                        distributed with
     *                        {@code FastTensors.multiplySumElementsOnFactorAndExpand}
     * @return the expanded tensor
     */
    public static Tensor expandProductOfSums1(Iterable<Tensor> factors,
                                              Transformation[] transformations,
                                              boolean expandSymbolic) {
        Transformation[] pairTransformations = expandSymbolic
                ? (Transformation[]) ArraysUtils.addAll(
                        new Transformation[]{expandIndexlessSubproduct}, transformations)
                : transformations;

        int capacity = 10;
        boolean isTensor = factors instanceof Tensor;
        if (isTensor) {
            if (!(factors instanceof Product)) {
                return (Tensor) factors;
            }
            capacity = ((Tensor) factors).size();
        }

        List<Tensor> nonSums = new ArrayList<>(capacity);
        List<Sum> sums = new ArrayList<>(capacity);
        for (Tensor factor : factors) {
            if (factor instanceof Sum) {
                sums.add((Sum) factor);
            } else {
                nonSums.add(factor);
            }
        }

        if (sums.isEmpty()) {
            return isTensor ? (Tensor) factors : Tensors.multiply(nonSums.toArray(new Tensor[0]));
        }
        if (sums.size() == 1) {
            return multiplySumOnFactors(sums.get(0), nonSums, transformations, expandSymbolic);
        }

        // Fold every sum except the last one. The last sum is multiplied together with the
        // collected non-sum factors in a single final step.
        int last = sums.size() - 1;
        Tensor accumulated = sums.get(0);
        for (int i = 1; i < last; ++i) {
            if (accumulated == null) {
                accumulated = sums.get(i);
            } else {
                accumulated = expandPairOfSums((Sum) accumulated, sums.get(i), pairTransformations);
                if (!(accumulated instanceof Sum)) {
                    // No longer a sum: treat it as a plain factor from now on.
                    nonSums.add(accumulated);
                    accumulated = null;
                }
            }
        }
        if (accumulated == null) {
            return multiplySumOnFactors(sums.get(last), nonSums, transformations, expandSymbolic);
        }
        return expandPairOfSums((Sum) accumulated, sums.get(last),
                nonSums.toArray(new Tensor[0]), pairTransformations);
    }

    /**
     * Distributes the product of {@code factors} over the terms of {@code sum}, then
     * applies {@code transformations}. The {@code ...AndExpand} variant is used when
     * {@code expandSymbolic} is set.
     */
    private static Tensor multiplySumOnFactors(Sum sum, List<Tensor> factors,
                                               Transformation[] transformations,
                                               boolean expandSymbolic) {
        Tensor factor = Tensors.multiply(factors.toArray(new Tensor[0]));
        Tensor distributed = expandSymbolic
                ? FastTensors.multiplySumElementsOnFactorAndExpand(sum, factor)
                : FastTensors.multiplySumElementsOnFactor(sum, factor);
        return apply(transformations, distributed);
    }

    /**
     * Returns {@code true} if {@code tensor} is a {@link Power} whose base is a {@link Sum}
     * and whose exponent satisfies {@link TensorUtils#isInteger}.
     */
    public static boolean isExpandablePower(Tensor tensor) {
        return tensor instanceof Power
                && tensor.get(0) instanceof Sum
                && TensorUtils.isInteger(tensor.get(1));
    }

    /**
     * Returns {@code true} if {@code tensor} is a {@link Sum} with at least one term that
     * has a non-empty index list; returns {@code false} for anything that is not a sum.
     */
    public static boolean sumContainsIndexed(Tensor tensor) {
        if (!(tensor instanceof Sum)) {
            return false;
        }
        for (Tensor term : tensor) {
            if (term.getIndices().size() != 0) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} if any direct element of {@code tensor} is a {@link Sum}. */
    private static boolean containsSum(Tensor tensor) {
        for (Tensor element : tensor) {
            if (element instanceof Sum) {
                return true;
            }
        }
        return false;
    }

    /**
     * Expands {@code argument ^ power} by repeated multiplication with the same
     * {@code argument} instance. No dummy indices are renamed, so this is meant for sums
     * without dummy indices. For {@code power < 2}, {@code argument} is returned unchanged.
     *
     * @param argument        the base sum
     * @param power           the exponent
     * @param transformations transformations applied to each expanded term
     * @return the expanded power
     */
    public static Tensor expandSymbolicPower(Sum argument, int power, Transformation[] transformations) {
        Tensor result = argument;
        for (int i = power - 1; i >= 1; --i) {
            result = multiplyPowerStep(result, argument, transformations);
        }
        return result;
    }

    /**
     * Expands {@code argument ^ power}, multiplying by a copy of {@code argument} with
     * renamed dummy indices at each step. For {@code power < 2}, {@code argument} is
     * returned unchanged.
     *
     * @param argument         the base sum
     * @param power            the exponent
     * @param forbiddenIndices index names that renamed dummies must not take
     * @param transformations  transformations applied to each expanded term
     * @return the expanded power
     */
    public static Tensor expandPower(Sum argument, int power, int[] forbiddenIndices,
                                     Transformation[] transformations) {
        TIntHashSet forbidden = new TIntHashSet(forbiddenIndices);
        TIntHashSet argumentIndices = TensorUtils.getAllIndicesNamesT(argument);
        forbidden.ensureCapacity(argumentIndices.size() * power);
        forbidden.addAll(argumentIndices);
        Tensor result = argument;
        for (int i = power - 1; i >= 1; --i) {
            // NOTE(review): renameDummy presumably records the newly chosen names in
            // `forbidden`, so successive copies get distinct dummies; confirm.
            Sum copy = (Sum) ApplyIndexMapping.renameDummy(argument, forbidden.toArray(), forbidden);
            result = multiplyPowerStep(result, copy, transformations);
        }
        return result;
    }

    /**
     * Performs one multiplication step of a power expansion: {@code accumulated * next}.
     *
     * <p>If the partial result is no longer a {@link Sum} (for example, because terms
     * cancelled), it is distributed over {@code next} as a plain factor. The previous
     * {@code (Sum)} cast threw a {@code ClassCastException} in that case.
     */
    private static Tensor multiplyPowerStep(Tensor accumulated, Sum next,
                                            Transformation[] transformations) {
        if (accumulated instanceof Sum) {
            return expandPairOfSums((Sum) accumulated, next, transformations);
        }
        return apply(transformations, FastTensors.multiplySumElementsOnFactor(next, accumulated));
    }
}
