package ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol;

import ai.libs.jaicore.basic.sets.Pair;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/filter/sampling/inmemory/casecontrol/ClassifierWeightedSampling.class */
public class ClassifierWeightedSampling<D extends ILabeledDataset<? extends ILabeledInstance>> extends PilotEstimateSampling<D> {
    private Logger logger;

    public ClassifierWeightedSampling(IClassifier iClassifier, Random random, D d) {
        super(d, iClassifier);
        this.logger = LoggerFactory.getLogger(ClassifierWeightedSampling.class);
        this.rand = random;
    }

    private double getMean(ILabeledDataset<?> iLabeledDataset) {
        double d = 0.0d;
        Iterator it = iLabeledDataset.iterator();
        while (it.hasNext()) {
            ILabeledInstance iLabeledInstance = (ILabeledInstance) it.next();
            try {
                d += getPilotEstimator().predict(iLabeledInstance).getProbabilityOfLabel(iLabeledInstance.getLabel());
            } catch (Exception e) {
                this.logger.error("Unexpected error in pilot estimator", e);
            }
        }
        return d / iLabeledDataset.size();
    }

    @Override // ai.libs.jaicore.ml.core.filter.sampling.inmemory.casecontrol.PilotEstimateSampling
    public List<Pair<ILabeledInstance, Double>> calculateAcceptanceThresholdsWithTrainedPilot(D d, IClassifier iClassifier) {
        int sample;
        double mean = getMean(d);
        double d2 = (10.0d * mean) + 1.0d;
        double d3 = d2 + (2.0d * mean);
        double[] dArr = new double[d.size()];
        for (int i = 0; i < dArr.length; i++) {
            try {
                IPrediction predict = iClassifier.predict((ILabeledInstance) d.get(i));
                if (predict.getLabelWithHighestProbability() == ((ILabeledInstance) d.get(i)).getLabel()) {
                    dArr[i] = d3 - predict.getProbabilityOfLabel(((ILabeledInstance) d.get(i)).getLabel());
                } else {
                    dArr[i] = d2 + predict.getProbabilityOfLabel(predict.getLabelWithHighestProbability());
                }
            } catch (Exception e) {
                dArr[i] = 0.0d;
            }
        }
        EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(IntStream.range(0, ((ILabeledDataset) getInput()).size()).toArray(), dArr);
        enumeratedIntegerDistribution.reseedRandomGenerator(this.rand.nextLong());
        int sampleSize = getSampleSize();
        HashSet hashSet = new HashSet();
        for (int i2 = 0; i2 < sampleSize; i2++) {
            do {
                sample = enumeratedIntegerDistribution.sample();
            } while (hashSet.contains(Integer.valueOf(sample)));
            hashSet.add(Integer.valueOf(sample));
        }
        ArrayList arrayList = new ArrayList();
        int size = d.size();
        for (int i3 = 0; i3 < size; i3++) {
            arrayList.add(new Pair((ILabeledInstance) d.get(i3), Double.valueOf(hashSet.contains(Integer.valueOf(i3)) ? 1.0d : 0.0d)));
        }
        return arrayList;
    }
}
