package dragon.ir.classification.featureselection;

import dragon.ir.classification.DocClassSet;
import dragon.ir.index.IndexReader;
import dragon.matrix.IntDenseMatrix;
import dragon.matrix.SparseMatrix;
import dragon.matrix.vector.DoubleVector;
import dragon.nlp.Token;
import dragon.nlp.compare.IndexComparator;
import dragon.nlp.compare.WeightComparator;
import dragon.util.SortedArray;
import java.io.Serializable;

/* loaded from: input_file:dragon/ir/classification/featureselection/ChiFeatureSelector.class */
/**
 * Feature selector that scores every term with a chi-square statistic computed
 * from a class-by-term count matrix and keeps the best-scoring share of terms.
 *
 * <p>For each term one chi-square value is computed per class. The term's final
 * score is either the class-prior weighted sum of those values (average mode)
 * or their maximum.
 */
public class ChiFeatureSelector extends AbstractFeatureSelector implements Serializable {
    private static final long serialVersionUID = 1;

    /** Multiplied by the number of terms/columns to obtain how many features to keep. */
    private final double topPercentage;

    /** {@code true}: prior-weighted sum of per-class scores; {@code false}: maximum per-class score. */
    private final boolean avgMode;

    /**
     * @param d fraction of the vocabulary to keep (stored as the top percentage)
     * @param z {@code true} to combine per-class chi-square values by their
     *          prior-weighted sum, {@code false} to use the maximum
     */
    public ChiFeatureSelector(double d, boolean z) {
        this.topPercentage = d;
        this.avgMode = z;
    }

    /**
     * Selects features using the term distribution derived from an index reader.
     *
     * @param indexReader source of the term distribution and of the collection's term count
     * @param docClassSet the labelled document classes
     * @return indices of the selected terms
     */
    @Override
    protected int[] getSelectedFeatures(IndexReader indexReader, DocClassSet docClassSet) {
        DoubleVector classPrior = getClassPrior(docClassSet);
        int totalDocs = countDocs(docClassSet);
        SortedArray ranked = computeTermCHI(getTermDistribution(indexReader, docClassSet), classPrior, totalDocs);
        return selectTopFeatures(ranked, indexReader.getCollection().getTermNum());
    }

    /**
     * Selects features using the term distribution derived from a document-term matrix.
     *
     * @param sparseMatrix matrix whose column count bounds the number of features
     * @param docClassSet  the labelled document classes
     * @return indices of the selected terms
     */
    @Override
    protected int[] getSelectedFeatures(SparseMatrix sparseMatrix, DocClassSet docClassSet) {
        DoubleVector classPrior = getClassPrior(docClassSet);
        int totalDocs = countDocs(docClassSet);
        SortedArray ranked = computeTermCHI(getTermDistribution(sparseMatrix, docClassSet), classPrior, totalDocs);
        return selectTopFeatures(ranked, sparseMatrix.columns());
    }

    /** Sums the document counts of all classes in the set. */
    private int countDocs(DocClassSet docClassSet) {
        int total = 0;
        for (int classIndex = 0; classIndex < docClassSet.getClassNum(); classIndex++) {
            total += docClassSet.getDocClass(classIndex).getDocNum();
        }
        return total;
    }

    /**
     * Takes the leading entries of the ranked token list and returns their term
     * indices in the order imposed by {@link IndexComparator}.
     *
     * @param ranked       tokens as returned by {@code computeTermCHI}
     * @param featureCount total number of candidate features (terms or matrix columns)
     * @return term indices of the kept tokens
     */
    private int[] selectTopFeatures(SortedArray ranked, double featureCount) {
        // A negative percentage would otherwise yield a negative capacity; treat it as "keep nothing".
        int quota = Math.max(0, (int) (this.topPercentage * featureCount));
        int keep = Math.min(ranked.size(), quota);

        // Re-collect the survivors under the index comparator before reading them out.
        SortedArray byIndex = new SortedArray(keep, new IndexComparator());
        for (int rank = 0; rank < keep; rank++) {
            byIndex.add(ranked.get(rank));
        }

        int[] selected = new int[byIndex.size()];
        for (int pos = 0; pos < selected.length; pos++) {
            selected[pos] = ((Token) byIndex.get(pos)).getIndex();
        }
        return selected;
    }

    /**
     * Builds one weighted token per term that has a positive column sum.
     *
     * @param intDenseMatrix count matrix read as (class, term); presumably the number
     *                       of documents of each class containing each term - confirm
     *                       against {@code getTermDistribution}
     * @param doubleVector   class prior, one entry per class
     * @param i              total number of documents over all classes
     * @return the tokens, handed back after switching the array to
     *         {@code WeightComparator(true)} (presumably highest weight first - confirm)
     */
    private SortedArray computeTermCHI(IntDenseMatrix intDenseMatrix, DoubleVector doubleVector, int i) {
        // Per-class document totals: prior scaled by the overall document count.
        DoubleVector classTotals = doubleVector.copy();
        classTotals.multiply(i);

        // Per-term totals: sum of each column of the count matrix.
        DoubleVector termTotals = new DoubleVector(intDenseMatrix.columns());
        for (int term = 0; term < intDenseMatrix.columns(); term++) {
            termTotals.set(term, intDenseMatrix.getColumnSum(term));
        }

        double grandTotal = i;
        DoubleVector perClassChi = new DoubleVector(classTotals.size());
        SortedArray tokens = new SortedArray(termTotals.size(), new IndexComparator());
        for (int term = 0; term < termTotals.size(); term++) {
            if (termTotals.get(term) > 0.0d) {
                for (int cls = 0; cls < classTotals.size(); cls++) {
                    perClassChi.set(cls, calChiSquare(intDenseMatrix.getInt(cls, term), classTotals.get(cls),
                            termTotals.get(term), grandTotal));
                }
                Token token = new Token(term, 0);
                if (this.avgMode) {
                    // Weighted by the unscaled class prior, not by the per-class totals.
                    token.setWeight(perClassChi.dotProduct(doubleVector));
                } else {
                    token.setWeight(perClassChi.getMaxValue());
                }
                tokens.add(token);
            }
        }
        tokens.setComparator(new WeightComparator(true));
        return tokens;
    }

    /**
     * Chi-square statistic of the 2x2 table (in class / not in class) x
     * (with term / without term).
     *
     * @param observed   count for the (in class, with term) cell
     * @param classTotal row total for the class
     * @param termTotal  column total for the term
     * @param grandTotal overall total
     * @return the sum of the four cell contributions; {@code 0} when the class or
     *         term total is zero
     */
    private double calChiSquare(double observed, double classTotal, double termTotal, double grandTotal) {
        if (classTotal == 0.0d || termTotal == 0.0d) {
            return 0.0d;
        }
        double otherClasses = grandTotal - classTotal;
        double withoutTerm = grandTotal - termTotal;

        double chi = chiCell(observed, (classTotal * termTotal) / grandTotal);
        chi += chiCell(classTotal - observed, (classTotal * withoutTerm) / grandTotal);
        chi += chiCell(termTotal - observed, (otherClasses * termTotal) / grandTotal);
        chi += chiCell((otherClasses - termTotal) + observed, (otherClasses * withoutTerm) / grandTotal);
        return chi;
    }

    /**
     * Contribution of a single cell: {@code (observed - expected)^2 / expected}.
     *
     * <p>A cell with no deviation contributes zero. This also covers the degenerate
     * tables (term total or class total equal to the grand total) where both the
     * deviation and the expected count are zero and the plain formula would
     * evaluate {@code 0/0} to NaN.
     */
    private static double chiCell(double observed, double expected) {
        double deviation = observed - expected;
        if (deviation == 0.0d) {
            return 0.0d;
        }
        return (deviation * deviation) / expected;
    }
}
