package au.csiro.variantspark.algo;

import java.util.Arrays;
import java.util.function.BiConsumer;

/* loaded from: input_file:au/csiro/variantspark/algo/JConfusionClassificationSplitter.class */
/**
 * {@link ClassificationSplitter} that finds a split point by first tabulating a
 * per-level / per-category count table (the "confusion" table) for the selected
 * samples and then sweeping the levels in ascending order, evaluating each
 * candidate boundary with {@code FastGini.splitGini}.
 *
 * <p>The instance owns scratch buffers ({@code confusion}, {@code leftSplitCounts},
 * {@code rightSplitCounts}, {@code leftRightGini}) that are reset and overwritten on
 * every {@code findSplit} call, and no synchronization is performed. Concurrent calls
 * on the same instance would therefore interfere with each other.
 */
public class JConfusionClassificationSplitter implements ClassificationSplitter {
    /** Running per-category counts for levels at or below the candidate split level. */
    private final int[] leftSplitCounts;
    /** Running per-category counts for levels above the candidate split level. */
    private final int[] rightSplitCounts;
    /** Count table indexed as {@code confusion[level][category]}. */
    private final int[][] confusion;
    /**
     * Two-element buffer passed to {@code FastGini.splitGini}; its elements 0 and 1 are
     * read back right after the call (presumably the left/right impurities written by
     * that call -- confirm against FastGini).
     */
    private final double[] leftRightGini = new double[2];
    /** Category of each sample, indexed by sample index. */
    private final int[] labels;
    private final int nCategories;
    private final int nLevels;

    /**
     * Creates a splitter with scratch buffers sized for the given dimensions.
     *
     * @param labels      category of each sample, indexed by sample index; the array is
     *                    stored by reference, not copied
     * @param nCategories number of distinct categories; label values are used as indices
     *                    into arrays of this length
     * @param nLevels     number of distinct feature levels; feature values are used as
     *                    indices into an array of this length
     */
    public JConfusionClassificationSplitter(int[] labels, int nCategories, int nLevels) {
        this.labels = labels;
        this.nCategories = nCategories;
        this.nLevels = nLevels;
        this.confusion = new int[nLevels][this.nCategories];
        this.leftSplitCounts = new int[this.nCategories];
        this.rightSplitCounts = new int[this.nCategories];
    }

    /**
     * Finds the best split for a feature stored as doubles. Each value is truncated to
     * {@code int} and used as the level index.
     *
     * @param values       feature value of each sample, indexed by sample index
     * @param splitIndices indices of the samples to consider
     * @return the best split found, or {@code null} if none was evaluated
     */
    @Override
    public SplitInfo findSplit(double[] values, int[] splitIndices) {
        return dofindSplit(splitIndices, (indices, counts) -> {
            for (int sample : indices) {
                int[] levelCounts = counts[(int) values[sample]];
                int category = this.labels[sample];
                levelCounts[category] = levelCounts[category] + 1;
            }
        });
    }

    /**
     * Finds the best split for a feature stored as ints. Each value is used directly as
     * the level index.
     *
     * @param values       feature value of each sample, indexed by sample index
     * @param splitIndices indices of the samples to consider
     * @return the best split found, or {@code null} if none was evaluated
     */
    @Override
    public SplitInfo findSplit(int[] values, int[] splitIndices) {
        return dofindSplit(splitIndices, (indices, counts) -> {
            for (int sample : indices) {
                int[] levelCounts = counts[values[sample]];
                int category = this.labels[sample];
                levelCounts[category] = levelCounts[category] + 1;
            }
        });
    }

    /**
     * Finds the best split for a feature stored as bytes. Each value is widened to
     * {@code int} and used as the level index.
     *
     * @param values       feature value of each sample, indexed by sample index
     * @param splitIndices indices of the samples to consider
     * @return the best split found, or {@code null} if none was evaluated
     */
    @Override
    public SplitInfo findSplit(byte[] values, int[] splitIndices) {
        return dofindSplit(splitIndices, (indices, counts) -> {
            for (int sample : indices) {
                int[] levelCounts = counts[values[sample]];
                int category = this.labels[sample];
                levelCounts[category] = levelCounts[category] + 1;
            }
        });
    }

    /**
     * Shared split search used by all {@code findSplit} overloads.
     *
     * <p>Steps: clear the count table, let {@code confusionCalc} populate it for the
     * selected samples, then move one level at a time from the right-hand totals to the
     * left-hand totals and score each boundary. The last level is never used as a
     * boundary because the right side would then be empty.
     *
     * @param splitIndices  indices of the samples to consider
     * @param confusionCalc callback that receives {@code splitIndices} and the zeroed
     *                      count table and increments {@code table[level][category]}
     *                      for each sample
     * @return a {@code SplitInfo} for the level with the strictly smallest value returned
     *         by {@code FastGini.splitGini} (the earliest level wins ties), or
     *         {@code null} when fewer than two indices are supplied or no candidate
     *         scored below {@code Double.MAX_VALUE}
     */
    private SplitInfo dofindSplit(int[] splitIndices, BiConsumer<int[], int[][]> confusionCalc) {
        if (splitIndices.length < 2) {
            return null;
        }

        // Reset the scratch count table left over from the previous call.
        for (int[] levelCounts : this.confusion) {
            Arrays.fill(levelCounts, 0);
        }
        confusionCalc.accept(splitIndices, this.confusion);

        // Start with everything on the right-hand side and nothing on the left.
        Arrays.fill(this.leftSplitCounts, 0);
        Arrays.fill(this.rightSplitCounts, 0);
        for (int[] levelCounts : this.confusion) {
            ArrayOps.addEq(this.rightSplitCounts, levelCounts);
        }

        SplitInfo bestSplit = null;
        double bestGini = Double.MAX_VALUE;
        for (int level = 0; level < this.nLevels - 1; level++) {
            // Shift this level's counts from the right side to the left side.
            ArrayOps.addEq(this.leftSplitCounts, this.confusion[level]);
            ArrayOps.subEq(this.rightSplitCounts, this.confusion[level]);
            double splitGini = FastGini.splitGini(
                    this.leftSplitCounts, this.rightSplitCounts, this.leftRightGini, true);
            if (splitGini < bestGini) {
                bestSplit = new SplitInfo(level, splitGini, this.leftRightGini[0], this.leftRightGini[1]);
                bestGini = splitGini;
            }
        }
        return bestSplit;
    }
}
