package cc.redberry.core.tensor;

import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indexmapping.IndexMappingBuffer;
import cc.redberry.core.indexmapping.IndexMappingBufferRecord;
import cc.redberry.core.indices.IndicesBuilder;
import cc.redberry.core.indices.IndicesFactory;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.functions.ScalarFunction;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.TensorUtils;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;

/* loaded from: input_file:cc/redberry/core/tensor/ApplyIndexMapping.class */
/**
 * Static helpers that rename the indices of a {@link Tensor}.
 *
 * <p>Two kinds of renaming are provided:
 * <ul>
 *   <li>renaming dummy indices of a tensor away from a set of forbidden names
 *       ({@link #renameDummy(Tensor, int[])});</li>
 *   <li>applying a mapping of free indices ({@code from -> to}) while also renaming any dummy
 *       index that would clash with a target name or a forbidden name
 *       ({@link #applyIndexMapping(Tensor, int[], int[], int[])} and the
 *       {@link IndexMappingBuffer} overloads).</li>
 * </ul>
 */
public final class ApplyIndexMapping {

    // NOTE(review): decompiler metadata reports that this class was originally private; it is
    // kept public here so that the visible interface of this file does not change.
    /**
     * {@link IndexMapping} backed by two parallel arrays.
     *
     * <p>An index is looked up by its name-with-type in {@code from}, which must be sorted in
     * ascending order because lookup uses {@link Arrays#binarySearch(int[], int)}. On a hit the
     * matching {@code to} entry is returned, XOR-ed with the raw state bits of the looked-up
     * index. Indices whose name is not in {@code from} are returned unchanged.
     */
    public static final class IndexMapper implements IndexMapping {
        private final int[] from;
        private final int[] to;

        /**
         * @param from sorted (ascending) source names; not copied
         * @param to   targets parallel to {@code from}; not copied
         */
        public IndexMapper(int[] from, int[] to) {
            this.from = from;
            this.to = to;
        }

        @Override
        public int map(int index) {
            int position = Arrays.binarySearch(from, IndicesUtils.getNameWithType(index));
            if (position < 0) {
                return index;
            }
            return IndicesUtils.getRawStateInt(index) ^ to[position];
        }

        /**
         * Reports whether this mapping makes two of the given indices coincide.
         *
         * <p>Each element is mapped and masked with {@link Integer#MAX_VALUE}, which clears the
         * highest bit. The array is then sorted and scanned for equal neighbours.
         * <b>The argument is modified in place</b> when it has two or more elements. Arrays of
         * length 0 or 1 return {@code false} untouched.
         *
         * @param names index names to test; overwritten with the mapped, sorted values
         * @return {@code true} if at least two mapped values are equal
         */
        boolean contract(int[] names) {
            if (names.length <= 1) {
                return false;
            }
            for (int i = 0; i < names.length; ++i) {
                names[i] = Integer.MAX_VALUE & map(names[i]);
            }
            Arrays.sort(names);
            for (int i = 1; i < names.length; ++i) {
                if (names[i] == names[i - 1]) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Renames those dummy indices of {@code tensor} whose names occur in {@code forbiddenNames}.
     * Every newly generated name is also added to {@code added}.
     *
     * <p>{@code tensor} itself is returned when nothing needs renaming. That covers:
     * {@code forbiddenNames} is empty, the tensor is a {@link Complex} or
     * {@link ScalarFunction}, it has no dummy indices, or none of its dummies is forbidden.
     *
     * @param tensor         tensor whose dummies may be renamed
     * @param forbiddenNames index names the dummies must not use (duplicates are tolerated)
     * @param added          receives every generated replacement name; if {@code null}, the
     *                       generated names are simply not recorded
     * @return the tensor with clashing dummies renamed, or {@code tensor} itself
     */
    public static Tensor renameDummy(Tensor tensor, int[] forbiddenNames, TIntHashSet added) {
        return renameDummyImpl(tensor, forbiddenNames, added);
    }

    /**
     * Renames those dummy indices of {@code tensor} whose names occur in {@code forbiddenNames}.
     *
     * <p>Same as {@link #renameDummy(Tensor, int[], TIntHashSet)} without recording the
     * generated names.
     *
     * @param tensor         tensor whose dummies may be renamed
     * @param forbiddenNames index names the dummies must not use (duplicates are tolerated)
     * @return the tensor with clashing dummies renamed, or {@code tensor} itself
     */
    public static Tensor renameDummy(Tensor tensor, int[] forbiddenNames) {
        return renameDummyImpl(tensor, forbiddenNames, null);
    }

    /** Shared implementation of both {@code renameDummy} overloads; {@code added} may be null. */
    private static Tensor renameDummyImpl(Tensor tensor, int[] forbiddenNames, TIntHashSet added) {
        if (forbiddenNames.length == 0) {
            return tensor;
        }
        if (tensor instanceof Complex || tensor instanceof ScalarFunction) {
            return tensor;
        }
        TIntHashSet occupied = TensorUtils.getAllDummyIndicesT(tensor);
        if (occupied.isEmpty()) {
            return tensor;
        }

        // Detect clashes against the *unmodified* dummy set. The previous approach extended the
        // set while scanning, so a repeated non-dummy forbidden name was mistaken for a clash.
        // Collecting into a set also drops repeated clashing names.
        TIntHashSet clashing = null;
        for (int name : forbiddenNames) {
            if (occupied.contains(name)) {
                if (clashing == null) {
                    clashing = new TIntHashSet();
                }
                clashing.add(name);
            }
        }
        if (clashing == null) {
            return tensor;
        }

        // Seed the generator with every name the result must avoid:
        // dummies, forbidden names and free index names.
        occupied.ensureCapacity(forbiddenNames.length);
        occupied.addAll(forbiddenNames);
        occupied.addAll(IndicesUtils.getIndicesNames(tensor.getIndices().getFree()));
        IndexGenerator generator = new IndexGenerator(occupied.toArray());

        int[] from = clashing.toArray();
        Arrays.sort(from); // IndexMapper requires sorted sources
        int[] to = new int[from.length];
        if (added != null) {
            added.ensureCapacity(from.length);
        }
        for (int i = from.length - 1; i >= 0; --i) {
            to[i] = generator.generate(IndicesUtils.getType(from[i]));
            if (added != null) {
                added.add(to[i]);
            }
        }
        return applyIndexMapping(tensor, new IndexMapper(from, to), false);
    }

    /**
     * Applies the mapping stored in {@code buffer} to the free indices of {@code tensor}.
     * No names are forbidden.
     *
     * @see #applyIndexMapping(Tensor, IndexMappingBuffer, int[])
     */
    public static Tensor applyIndexMapping(Tensor tensor, IndexMappingBuffer buffer) {
        return applyIndexMapping(tensor, buffer, new int[0]);
    }

    /**
     * Applies the mapping stored in {@code buffer} to the free indices of {@code tensor}.
     * Dummy indices that clash with a target name or with {@code forbidden} are renamed.
     *
     * <p>If {@link IndexMappingBuffer#getSign()} is {@code true}, the result is negated.
     *
     * @param tensor    tensor to rename
     * @param buffer    mapping whose keys must be exactly the free index names of {@code tensor}
     * @param forbidden additional names the dummies of the result must avoid
     * @return the renamed (and possibly negated) tensor
     * @throws IllegalArgumentException if the keys of {@code buffer} differ from the free index
     *                                  names of {@code tensor}
     */
    public static Tensor applyIndexMapping(Tensor tensor, IndexMappingBuffer buffer, int[] forbidden) {
        if (buffer.isEmpty()) {
            if (tensor.getIndices().getFree().size() != 0) {
                throw new IllegalArgumentException("From indices are not equal to free indices of tensor.");
            }
            Tensor renamed = renameDummy(tensor, forbidden);
            return buffer.getSign() ? Tensors.negate(renamed) : renamed;
        }
        // NOTE(review): for scalars a non-empty buffer is returned as-is (sign not applied),
        // mirroring the original behaviour.
        if (tensor instanceof Complex || tensor instanceof ScalarFunction) {
            return tensor;
        }

        Map<Integer, IndexMappingBufferRecord> map = buffer.getMap();
        int[] from = new int[map.size()];
        int[] to = new int[map.size()];
        int position = 0;
        for (Map.Entry<Integer, IndexMappingBufferRecord> entry : map.entrySet()) {
            IndexMappingBufferRecord record = entry.getValue();
            from[position] = entry.getKey();
            // When the record reports differing states, the highest bit is set on the target.
            // IndexMapper#map XORs it into the mapped index (presumably toggling its state).
            to[position] = record.getIndexName() ^ (record.diffStatesInitialized() ? Integer.MIN_VALUE : 0);
            ++position;
        }

        int[] freeNames = IndicesUtils.getIndicesNames(tensor.getIndices().getFree());
        Arrays.sort(freeNames);
        int[] sortedFrom = from.clone();
        Arrays.sort(sortedFrom);
        if (!Arrays.equals(freeNames, sortedFrom)) {
            throw new IllegalArgumentException("From indices are not equal to free indices of tensor.");
        }

        Tensor result = applyIndexMappingFromPreparedSource(tensor, from, to, forbidden);
        return buffer.getSign() ? Tensors.negate(result) : result;
    }

    /**
     * Throws unless the sorted free indices of {@code tensor} (with their states) equal
     * {@code sortedFrom}.
     */
    private static void checkConsistent(Tensor tensor, int[] sortedFrom) {
        int[] free = tensor.getIndices().getFree().getAllIndices().copy();
        Arrays.sort(free);
        if (!Arrays.equals(free, sortedFrom)) {
            throw new IllegalArgumentException("From indices are not equal to free indices of tensor.");
        }
    }

    /**
     * Maps the free indices {@code from[i] -> to[i]} of {@code tensor}. Dummy indices that clash
     * with a target name or with {@code forbidden} are renamed.
     *
     * <p>The input arrays are not modified; copies are sorted internally. When {@code from} is
     * empty, the tensor must have no free indices and {@code to} must be empty. In that case only
     * the dummies are renamed.
     *
     * @param tensor    tensor to rename
     * @param from      the free indices of {@code tensor} (with states), in any order
     * @param to        target indices, parallel to {@code from}
     * @param forbidden additional names the dummies of the result must avoid
     * @return the renamed tensor ({@link Complex} and {@link ScalarFunction} are returned as-is)
     * @throws IllegalArgumentException if {@code from} does not match the free indices of
     *                                  {@code tensor}, or if {@code from} and {@code to} differ
     *                                  in length
     */
    public static Tensor applyIndexMapping(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        if (from.length != 0) {
            // Work on copies: applyIndexMapping1 sorts and rewrites both arrays in place.
            return applyIndexMapping1(tensor, from.clone(), to.clone(), forbidden);
        }
        if (tensor.getIndices().getFree().size() == 0 && to.length == 0) {
            return renameDummy(tensor, forbidden);
        }
        throw new IllegalArgumentException("from length does not match free indices size or to length.");
    }

    /** Validates and normalises owned copies of {@code from}/{@code to}, then applies them. */
    private static Tensor applyIndexMapping1(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        if (tensor instanceof Complex || tensor instanceof ScalarFunction) {
            return tensor;
        }
        if (from.length != to.length) {
            throw new IllegalArgumentException("from length does not match to length.");
        }
        ArraysUtils.quickSort(from, to);
        checkConsistent(tensor, from);
        // Strip the raw state bits from each source and XOR the same bits onto its target.
        // IndexMapper#map XORs the actual index state back in, so an index whose state equals
        // from[i] is mapped to exactly the original to[i].
        for (int i = from.length - 1; i >= 0; --i) {
            int state = IndicesUtils.getRawStateInt(from[i]);
            from[i] ^= state;
            to[i] ^= state;
        }
        // Stripping states can change the order; IndexMapper needs sorted sources.
        ArraysUtils.quickSort(from, to);
        return applyIndexMappingFromPreparedSource(tensor, from, to, forbidden);
    }

    /**
     * Extends the free-index mapping with renamings for clashing dummies, then applies it.
     *
     * <p>A dummy clashes when its name equals the name-with-type of some target in {@code to} or
     * some entry of {@code forbidden}. Clashing dummies get fresh names from an
     * {@link IndexGenerator} seeded with the targets, the forbidden names and all dummies.
     */
    private static Tensor applyIndexMappingFromPreparedSource(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        // Names the dummies must avoid, sorted for binarySearch below.
        int[] reserved = new int[to.length + forbidden.length];
        System.arraycopy(to, 0, reserved, 0, to.length);
        System.arraycopy(forbidden, 0, reserved, to.length, forbidden.length);
        for (int i = reserved.length - 1; i >= 0; --i) {
            reserved[i] = IndicesUtils.getNameWithType(reserved[i]);
        }
        Arrays.sort(reserved);

        IntArrayList fromList = new IntArrayList(from.length);
        IntArrayList toList = new IntArrayList(to.length);
        fromList.addAll(from);
        toList.addAll(to);

        int[] dummies = TensorUtils.getAllDummyIndicesT(tensor).toArray();
        int[] occupied = new int[reserved.length + dummies.length];
        System.arraycopy(reserved, 0, occupied, 0, reserved.length);
        System.arraycopy(dummies, 0, occupied, reserved.length, dummies.length);
        IndexGenerator generator = new IndexGenerator(occupied);
        for (int dummy : dummies) {
            if (Arrays.binarySearch(reserved, dummy) >= 0) {
                fromList.add(dummy);
                toList.add(generator.generate(IndicesUtils.getType(dummy)));
            }
        }

        int[] allFrom = fromList.toArray();
        int[] allTo = toList.toArray();
        ArraysUtils.quickSort(allFrom, allTo);
        return applyIndexMapping(tensor, new IndexMapper(allFrom, allTo));
    }

    /**
     * Applies {@code mapper} to {@code tensor}, first checking whether it makes two free
     * indices of the tensor coincide.
     */
    private static Tensor applyIndexMapping(Tensor tensor, IndexMapper mapper) {
        if (tensor instanceof SimpleTensor) {
            return applyIndexMapping(tensor, mapper, false);
        }
        if (tensor instanceof Complex || tensor instanceof ScalarFunction) {
            return tensor;
        }
        boolean contractsFree = mapper.contract(IndicesUtils.getIndicesNames(tensor.getIndices().getFree()));
        return applyIndexMapping(tensor, mapper, contractsFree);
    }

    /**
     * Recursively applies {@code mapper} to {@code tensor}.
     *
     * <p>Unchanged subtensors are reused by identity where possible, and {@code tensor} itself
     * is returned when a {@link SimpleTensor}, {@link Power}, {@link Sum} or non-contracting
     * {@link Product} is unaffected.
     *
     * @param contracts when {@code true} (the mapping makes two free indices coincide), a
     *                  composite tensor is rebuilt through its {@link TensorBuilder}, with each
     *                  child mapped via {@link #applyIndexMapping(Tensor, IndexMapper)}
     */
    private static Tensor applyIndexMapping(Tensor tensor, IndexMapper mapper, boolean contracts) {
        if (tensor instanceof SimpleTensor) {
            SimpleTensor simpleTensor = (SimpleTensor) tensor;
            SimpleIndices oldIndices = simpleTensor.getIndices();
            SimpleIndices newIndices = oldIndices.applyIndexMapping((IndexMapping) mapper);
            if (oldIndices == newIndices) {
                return tensor;
            }
            if (tensor instanceof TensorField) {
                TensorField field = (TensorField) simpleTensor;
                return Tensors.field(field.name, newIndices, field.argIndices, field.args);
            }
            return Tensors.simpleTensor(simpleTensor.name, newIndices);
        }
        if (tensor instanceof Complex || tensor instanceof ScalarFunction) {
            return tensor;
        }
        if (tensor instanceof Expression) {
            boolean contractsAny = mapper.contract(IndicesUtils.getIndicesNames(tensor.getIndices()));
            return Tensors.expression(
                    applyIndexMapping(tensor.get(0), mapper, contractsAny),
                    applyIndexMapping(tensor.get(1), mapper, contractsAny));
        }
        if (tensor instanceof Power) {
            Tensor base = tensor.get(0);
            Tensor newBase = applyIndexMapping(base, mapper, false);
            return base == newBase ? tensor : new Power(newBase, tensor.get(1));
        }
        if (contracts) {
            TensorBuilder builder = tensor.getBuilder();
            Iterator<Tensor> children = tensor.iterator();
            while (children.hasNext()) {
                builder.put(applyIndexMapping(children.next(), mapper));
            }
            return builder.build();
        }
        if (tensor instanceof Sum) {
            Tensor[] summands = ((Sum) tensor).data;
            Tensor[] newSummands = null; // copied lazily on the first changed summand
            for (int i = summands.length - 1; i >= 0; --i) {
                Tensor summand = summands[i];
                Tensor mapped = applyIndexMapping(summand, mapper, false);
                if (summand != mapped) {
                    if (newSummands == null) {
                        newSummands = summands.clone();
                    }
                    newSummands[i] = mapped;
                }
            }
            if (newSummands == null) {
                return tensor;
            }
            return new Sum(newSummands, IndicesFactory.create(newSummands[0].getIndices().getFree()));
        }
        if (!(tensor instanceof Product)) {
            throw new RuntimeException("Unsupported tensor type: " + tensor.getClass().getName());
        }

        Product product = (Product) tensor;
        // Indexless factors and indexed factors (product.data) are mapped separately; each
        // array is copied lazily on its first changed element.
        Tensor[] indexless = product.getIndexless();
        Tensor[] newIndexless = null;
        for (int i = indexless.length - 1; i >= 0; --i) {
            Tensor factor = indexless[i];
            Tensor mapped = applyIndexMapping(factor, mapper, false);
            if (factor != mapped) {
                if (newIndexless == null) {
                    newIndexless = indexless.clone();
                }
                newIndexless[i] = mapped;
            }
        }
        Tensor[] data = product.data;
        Tensor[] newData = null;
        for (int i = data.length - 1; i >= 0; --i) {
            Tensor factor = data[i];
            Tensor mapped = applyIndexMapping(factor, mapper, false);
            if (factor != mapped) {
                if (newData == null) {
                    newData = data.clone();
                }
                newData[i] = mapped;
            }
        }
        if (newIndexless == null) {
            newIndexless = indexless;
        }
        if (newData == null) {
            // Indexed factors unchanged: reuse the product's indices and content reference.
            return new Product(product.indices, product.factor, newIndexless, data, product.contentReference);
        }
        return new Product(new IndicesBuilder().append(newData).getIndices(), product.factor, newIndexless, newData);
    }
}
