package cc.redberry.core.transformations.substitutions;

import cc.redberry.core.tensor.ApplyIndexMapping;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.tensor.iterator.DummyPayload;
import cc.redberry.core.tensor.iterator.Payload;
import cc.redberry.core.tensor.iterator.PayloadFactory;
import cc.redberry.core.tensor.iterator.StackPosition;
import cc.redberry.core.tensor.iterator.TraverseGuide;
import cc.redberry.core.tensor.iterator.TraverseState;
import cc.redberry.core.tensor.iterator.TreeIterator;
import cc.redberry.core.tensor.iterator.TreeTraverseIterator;
import cc.redberry.core.utils.ByteBackedBitArray;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.TCollections;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;

/* loaded from: input_file:cc/redberry/core/transformations/substitutions/SubstitutionIterator.class */
/**
 * Tree iterator used by substitutions.
 *
 * <p>It walks a tensor and yields each node when the underlying traversal is no longer in
 * the {@link TraverseState#Entering} state for it. Along the way it maintains, for the
 * current position, the set of index names that are "forbidden". {@link #safeSet(Tensor)}
 * passes these names to {@code ApplyIndexMapping.renameDummy} before replacing the node.
 *
 * <p>Every stack position carries a {@link ForbiddenContainer} payload chosen by
 * {@link FCPayloadFactory}. When a node is replaced through {@link #set(Tensor)}, the dummy
 * index names that disappeared and appeared are submitted to the parent payload. The parent
 * updates its own bookkeeping and forwards the change upward where its type requires it.
 */
public final class SubstitutionIterator implements TreeIterator {
    /** Shared immutable empty set, used wherever "no index names" must be passed or returned. */
    private static final TIntSet EMPTY_INT_SET = TCollections.unmodifiableSet(new TIntHashSet(0));

    /**
     * Payload used for scalar functions.
     *
     * <p>Children see no forbidden names, and their submissions are ignored. When a modified
     * function node is left, the whole node is passed through {@code ApplyIndexMapping.renameDummy}
     * with the parent's forbidden names. All index names of the result are then submitted to the
     * parent as added. The returned tensor presumably replaces the node; {@code null} leaves it
     * untouched.
     */
    private static final ForbiddenContainer SCALAR_FUNCTION_CONTAINER = new ForbiddenContainer() {
        @Override
        public TIntSet getForbidden() {
            return EMPTY_INT_SET;
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            // Intentionally empty: changes inside the function are reconciled as a whole in onLeaving.
        }

        @Override
        public Tensor onLeaving(StackPosition<ForbiddenContainer> stackPosition) {
            if (!stackPosition.isModified()) {
                return null;
            }
            StackPosition<ForbiddenContainer> previous = stackPosition.previous();
            if (previous == null) {
                return null;
            }
            ForbiddenContainer parent = previous.getPayload();
            Tensor renamed = ApplyIndexMapping.renameDummy(stackPosition.getTensor(), parent.getForbidden().toArray());
            parent.submit(EMPTY_INT_SET, TensorUtils.getAllIndicesNamesT(renamed));
            return renamed;
        }
    };

    /**
     * Payload that tracks nothing: it forbids no names, ignores submissions and never rewrites
     * its node. {@link FCPayloadFactory} uses it both as the root marker and for tensor fields.
     */
    private static final ForbiddenContainer EMPTY_CONTAINER = new ForbiddenContainer() {
        @Override
        public TIntSet getForbidden() {
            return EMPTY_INT_SET;
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            // Intentionally empty: nothing is tracked here.
        }

        @Override
        public Tensor onLeaving(StackPosition<ForbiddenContainer> stackPosition) {
            return null;
        }
    };

    /** Underlying traversal; each stack position carries a {@link ForbiddenContainer}. */
    private final TreeTraverseIterator<ForbiddenContainer> innerIterator;

    /**
     * Base class for payloads that keep their own lazily computed forbidden set.
     */
    public static abstract class AbstractFC extends DummyPayload<ForbiddenContainer> implements ForbiddenContainer {
        /** Stack position this payload belongs to. */
        protected final StackPosition<ForbiddenContainer> position;
        /** Forbidden index names; {@code null} until {@link #insureInitialized()} has run. */
        protected TIntSet forbidden;
        /** Tensor at {@link #position} as it was when the position was created. */
        protected final Tensor tensor;

        private AbstractFC(StackPosition<ForbiddenContainer> stackPosition) {
            this.forbidden = null;
            this.position = stackPosition;
            this.tensor = stackPosition.getInitialTensor();
        }

        /** Computes {@link #forbidden} (and any auxiliary state) on first use; no-op afterwards. */
        public abstract void insureInitialized();

        /**
         * Returns a fresh copy of this payload's forbidden names. The copy excludes all index
         * names of the initial child at the current index, so a child is never forbidden from
         * reusing its own names.
         */
        @Override
        public TIntSet getForbidden() {
            insureInitialized();
            TIntHashSet result = new TIntHashSet(forbidden);
            result.removeAll(TensorUtils.getAllIndicesNamesT(tensor.get(position.currentIndex())));
            return result;
        }
    }

    /** Chooses the {@link ForbiddenContainer} for each stack position from the node type and its parent payload. */
    private static final class FCPayloadFactory implements PayloadFactory<ForbiddenContainer> {
        private FCPayloadFactory() {
        }

        @Override
        public boolean allowLazyInitialization() {
            return true;
        }

        @Override
        public ForbiddenContainer create(StackPosition<ForbiddenContainer> stackPosition) {
            Tensor node = stackPosition.getInitialTensor();
            StackPosition<ForbiddenContainer> previous = stackPosition.previous();
            ForbiddenContainer parent = previous == null ? EMPTY_CONTAINER : previous.getPayload();

            if (parent == EMPTY_CONTAINER) {
                // Start of a tracked region (tree root, or directly below an untracked node):
                // only products need bookkeeping here.
                return node instanceof Product ? new TopProductFC(stackPosition) : EMPTY_CONTAINER;
            }
            // Order of checks matches the original dispatch.
            if (node instanceof Product) {
                return new ProductFC(stackPosition);
            }
            if (node instanceof Sum) {
                return new SumFC(stackPosition);
            }
            if (node instanceof TensorField) {
                return EMPTY_CONTAINER;
            }
            if (node instanceof ScalarFunction) {
                return SCALAR_FUNCTION_CONTAINER;
            }
            return new TransparentFC(parent);
        }
    }

    /** Payload that knows which index names are forbidden at its position and accepts change reports from children. */
    public interface ForbiddenContainer extends Payload<ForbiddenContainer> {
        /** Returns the index names a child at the current index must not use for new dummies. */
        TIntSet getForbidden();

        /**
         * Reports that, in the child at the current index, the dummy names {@code removed}
         * disappeared and the dummy names {@code added} appeared.
         */
        void submit(TIntSet removed, TIntSet added);
    }

    /**
     * Payload for a product below another tracked payload.
     *
     * <p>Its forbidden set is the parent's forbidden names plus all index names of the initial
     * product. Submissions update this set and are forwarded to the parent unchanged.
     */
    public static final class ProductFC extends AbstractFC {
        private ProductFC(StackPosition<ForbiddenContainer> stackPosition) {
            super(stackPosition);
        }

        @Override
        public void insureInitialized() {
            if (forbidden != null) {
                return;
            }
            forbidden = new TIntHashSet(position.previous().getPayload().getForbidden());
            forbidden.addAll(TensorUtils.getAllIndicesNamesT(tensor));
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            insureInitialized();
            forbidden.addAll(added);
            forbidden.removeAll(removed);
            position.previous().getPayload().submit(removed, added);
        }
    }

    /**
     * Payload for a sum below another tracked payload.
     *
     * <p>For every dummy index of the initial sum, a bit array records which summands currently
     * use it. A tracked dummy is reported upward as removed only when its last user drops it. It
     * is reported as added only when no summand was using it before. Dummies unknown to the
     * initial sum are always reported as added and never as removed. This keeps them forbidden
     * upstream, which errs on the side of renaming.
     *
     * <p>The forbidden set is a snapshot of the parent's, taken at initialization. Unlike
     * {@link AbstractFC#getForbidden()}, the current summand's names are not excluded from it.
     */
    public static final class SumFC extends AbstractFC {
        /** Sorted dummy index names of the initial sum. */
        private int[] allDummyIndices;
        /** {@code usedArrays[k]} has bit {@code i} set iff summand {@code i} uses {@code allDummyIndices[k]}. */
        private ByteBackedBitArray[] usedArrays;

        private SumFC(StackPosition<ForbiddenContainer> stackPosition) {
            super(stackPosition);
        }

        @Override
        public void insureInitialized() {
            if (forbidden != null) {
                return;
            }
            forbidden = position.previous().getPayload().getForbidden();
            allDummyIndices = TensorUtils.getAllDummyIndicesT(tensor).toArray();
            Arrays.sort(allDummyIndices);
            int summands = tensor.size();
            usedArrays = new ByteBackedBitArray[allDummyIndices.length];
            for (int slot = 0; slot < usedArrays.length; ++slot) {
                usedArrays[slot] = new ByteBackedBitArray(summands);
            }
            for (int i = 0; i < summands; ++i) {
                TIntIterator it = TensorUtils.getAllDummyIndicesT(tensor.get(i)).iterator();
                while (it.hasNext()) {
                    usedArrays[Arrays.binarySearch(allDummyIndices, it.next())].set(i);
                }
            }
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            insureInitialized();
            final int current = position.currentIndex();

            // Removed names: report upward only those that no summand uses any more.
            TIntSet parentRemoved = null;
            TIntIterator removedIt = removed.iterator();
            while (removedIt.hasNext()) {
                int name = removedIt.next();
                int slot = Arrays.binarySearch(allDummyIndices, name);
                if (slot < 0) {
                    // Not a dummy of the initial sum: it was introduced by an earlier replacement
                    // and is not tracked per summand (another summand may still use it). Indexing
                    // with a negative slot would throw; keep it forbidden upstream instead.
                    continue;
                }
                ByteBackedBitArray users = usedArrays[slot];
                users.clear(current);
                if (users.bitCount() == 0) {
                    if (parentRemoved == null) {
                        parentRemoved = new TIntHashSet(removed.size());
                    }
                    parentRemoved.add(name);
                }
            }
            if (parentRemoved == null) {
                parentRemoved = EMPTY_INT_SET;
            }

            // Added names: a tracked name is already known upstream only while some summand
            // uses it; otherwise (including after its removal was reported) it must be re-reported.
            TIntSet parentAdded = new TIntHashSet(added);
            TIntIterator addedIt = parentAdded.iterator();
            while (addedIt.hasNext()) {
                int slot = Arrays.binarySearch(allDummyIndices, addedIt.next());
                if (slot < 0) {
                    continue;
                }
                ByteBackedBitArray users = usedArrays[slot];
                if (users.bitCount() > 0) {
                    addedIt.remove();
                }
                users.set(current);
            }
            position.previous().getPayload().submit(parentRemoved, parentAdded);
        }

        /** Returns a fresh copy of the snapshot taken from the parent at initialization. */
        @Override
        public TIntSet getForbidden() {
            insureInitialized();
            return new TIntHashSet(forbidden);
        }
    }

    /**
     * Payload for a product that starts a tracked region (its parent payload is the empty one).
     * Its forbidden set is all index names of the initial product. Submissions update this set
     * and are not forwarded.
     */
    public static final class TopProductFC extends AbstractFC {
        private TopProductFC(StackPosition<ForbiddenContainer> stackPosition) {
            super(stackPosition);
        }

        @Override
        public void insureInitialized() {
            if (forbidden != null) {
                return;
            }
            forbidden = TensorUtils.getAllIndicesNamesT(tensor);
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            insureInitialized();
            forbidden.addAll(added);
            forbidden.removeAll(removed);
        }
    }

    /** Payload that delegates both queries and submissions to its parent payload. */
    public static final class TransparentFC extends DummyPayload<ForbiddenContainer> implements ForbiddenContainer {
        private final ForbiddenContainer parent;

        private TransparentFC(ForbiddenContainer parent) {
            this.parent = parent;
        }

        @Override
        public TIntSet getForbidden() {
            return parent.getForbidden();
        }

        @Override
        public void submit(TIntSet removed, TIntSet added) {
            parent.submit(removed, added);
        }
    }

    /**
     * Creates an iterator over {@code tensor}.
     *
     * @param tensor tensor to traverse
     */
    public SubstitutionIterator(Tensor tensor) {
        this.innerIterator = new TreeTraverseIterator<>(tensor, new FCPayloadFactory());
    }

    /**
     * Creates an iterator over {@code tensor} whose traversal is restricted by {@code traverseGuide}.
     *
     * @param tensor        tensor to traverse
     * @param traverseGuide guide passed to the underlying {@link TreeTraverseIterator}
     */
    public SubstitutionIterator(Tensor tensor, TraverseGuide traverseGuide) {
        this.innerIterator = new TreeTraverseIterator<>(tensor, traverseGuide, new FCPayloadFactory());
    }

    /**
     * Advances to the next node, skipping {@link TraverseState#Entering} states.
     *
     * @return the current node, or {@code null} when the traversal is exhausted
     */
    @Override
    public Tensor next() {
        TraverseState state;
        do {
            state = innerIterator.next();
        } while (state == TraverseState.Entering);
        if (state == null) {
            return null;
        }
        return innerIterator.current();
    }

    /**
     * Replaces the current node without any forbidden-index bookkeeping. The caller is
     * responsible for index consistency.
     *
     * @param tensor replacement
     */
    public void unsafeSet(Tensor tensor) {
        innerIterator.set(tensor);
    }

    /**
     * Replaces the current node and reports the change in dummy index names to the parent payload.
     * The call is a no-op if {@code tensor} is the current node itself. Zero/indeterminate or
     * symbolic replacements are set directly, without a free-index check or a report.
     *
     * @param tensor replacement; must have the same free indices as the current node (any order)
     * @throws IllegalArgumentException if the free indices differ
     */
    @Override
    public void set(Tensor tensor) {
        Tensor current = innerIterator.current();
        if (current == tensor) {
            return;
        }
        if (TensorUtils.isZeroOrIndeterminate(tensor) || TensorUtils.isSymbolic(tensor)) {
            innerIterator.set(tensor);
            return;
        }
        if (!tensor.getIndices().getFree().equalsRegardlessOrder(current.getIndices().getFree())) {
            throw new IllegalArgumentException("Substitution with different free indices.");
        }
        StackPosition<ForbiddenContainer> previous = innerIterator.currentStackPosition().previous();
        if (previous != null) {
            TIntHashSet oldDummies = TensorUtils.getAllDummyIndicesT(current);
            TIntHashSet newDummies = TensorUtils.getAllDummyIndicesT(tensor);
            TIntHashSet removed = new TIntHashSet(oldDummies);
            removed.removeAll(newDummies);
            TIntHashSet added = new TIntHashSet(newDummies);
            added.removeAll(oldDummies);
            previous.getPayload().submit(removed, added);
        }
        innerIterator.set(tensor);
    }

    /**
     * Like {@link #set(Tensor)}, but first passes {@code tensor} through
     * {@code ApplyIndexMapping.renameDummy} with {@link #getForbidden()}. Nothing happens if
     * {@code tensor} is the current node itself.
     *
     * @param tensor replacement
     */
    public void safeSet(Tensor tensor) {
        if (innerIterator.current() != tensor) {
            set(ApplyIndexMapping.renameDummy(tensor, getForbidden()));
        }
    }

    /** @return whether the current stack position has been modified */
    public boolean isCurrentModified() {
        return innerIterator.currentStackPosition().isModified();
    }

    @Override
    public Tensor result() {
        return innerIterator.result();
    }

    @Override
    public int depth() {
        return innerIterator.depth();
    }

    /**
     * Returns the index names forbidden for the current node. These come from its parent's
     * payload; an empty array is returned at the root.
     *
     * @return forbidden index names (fresh array)
     */
    public int[] getForbidden() {
        StackPosition<ForbiddenContainer> previous = innerIterator.currentStackPosition().previous();
        if (previous == null) {
            return new int[0];
        }
        return previous.getPayload().getForbidden().toArray();
    }
}
