package ai.libs.jaicore.ml.functionprediction.learner.learningcurveextrapolation;

import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
import ai.libs.jaicore.ml.core.filter.FilterBasedDatasetSplitter;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.ASamplingAlgorithm;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.LabelBasedStratifiedSamplingFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.IRerunnableSamplingAlgorithmFactory;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.factories.interfaces.ISamplingAlgorithmFactory;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.learningcurve.ILearningCurve;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/functionprediction/learner/learningcurveextrapolation/LearningCurveExtrapolator.class */
public class LearningCurveExtrapolator implements ILoggingCustomizable {

    /** Learner that is trained on each subsample. */
    protected ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner;

    /** Full dataset that was passed in; its size is handed to the extrapolation method. */
    protected ILabeledDataset<? extends ILabeledInstance> dataset;

    /** Training part of the stratified split; subsamples for the anchor points are drawn from it. */
    protected ILabeledDataset<? extends ILabeledInstance> train;

    /** Test part of the stratified split; every trained learner is evaluated on it. */
    protected ILabeledDataset<? extends ILabeledInstance> test;

    /** Factory creating one sampling algorithm per anchor point. */
    protected ISamplingAlgorithmFactory<ILabeledDataset<?>, ? extends ASamplingAlgorithm<ILabeledDataset<?>>> samplingAlgorithmFactory;

    /** Random source (seeded in the constructor) passed to the sampling algorithms. */
    protected Random random;

    /** Method that fits a learning curve to the observed (anchor point, error rate) pairs. */
    protected LearningCurveExtrapolationMethod extrapolationMethod;

    private final int[] anchorPoints;
    private final double[] yValues;
    private final int[] trainingTimes;
    private Logger logger = LoggerFactory.getLogger(LearningCurveExtrapolator.class);

    /** Sampling algorithm of the most recent anchor point; handed to rerunnable factories as the previous run. */
    protected ASamplingAlgorithm<ILabeledDataset<? extends ILabeledInstance>> samplingAlgorithm = null;

    /**
     * Creates an extrapolator and immediately splits the dataset into a stratified train/test split.
     *
     * @param extrapolationMethod method used to fit the curve to the observed anchor point values
     * @param learner learner trained on each subsample
     * @param dataset dataset to split into train and test data
     * @param trainsplit training portion passed to the stratified splitter (presumably a fraction in (0, 1) — confirm)
     * @param anchorPoints subsample sizes at which the learner is evaluated; the array is stored, not copied
     * @param samplingAlgorithmFactory factory creating the subsampling algorithm for each anchor point
     * @param seed seed for both the split/shuffle randomness and the sampling randomness
     * @throws DatasetCreationException if the train/test split fails
     * @throws InterruptedException if interrupted while splitting
     */
    /* JADX WARN: Multi-variable type inference failed */
    public LearningCurveExtrapolator(final LearningCurveExtrapolationMethod extrapolationMethod, final ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner, final ILabeledDataset<?> dataset, final double trainsplit, final int[] anchorPoints, final ISamplingAlgorithmFactory<ILabeledDataset<?>, ? extends ASamplingAlgorithm<ILabeledDataset<?>>> samplingAlgorithmFactory, final long seed) throws DatasetCreationException, InterruptedException {
        this.extrapolationMethod = extrapolationMethod;
        this.learner = learner;
        this.dataset = dataset;
        this.anchorPoints = anchorPoints;
        // Allocate result buffers before the (potentially expensive) split so a null anchorPoints array fails fast.
        this.yValues = new double[anchorPoints.length];
        this.trainingTimes = new int[anchorPoints.length];
        this.samplingAlgorithmFactory = samplingAlgorithmFactory;
        this.random = new Random(seed);
        this.createSplit(trainsplit, seed);
    }

    /**
     * Evaluates the learner at every anchor point and extrapolates a learning curve from the observations.
     *
     * <p>For each anchor point, a subsample of that size is drawn from the training data. If the factory is
     * rerunnable, the previous sampling run is reused. The learner is trained on the subsample and evaluated
     * on the test data. The training time and the error rate are recorded in {@link #getTrainingTimes()} and
     * {@link #getyValues()}.
     *
     * @return the learning curve produced by the configured extrapolation method
     * @throws InvalidAnchorPointsException if the extrapolation method rejects the anchor points
     * @throws AlgorithmException if subsampling, training/testing, or extrapolation fails
     * @throws InterruptedException if the thread is interrupted
     */
    public ILearningCurve extrapolateLearningCurve() throws InvalidAnchorPointsException, AlgorithmException, InterruptedException {
        try {
            final ILabeledDataset<? extends ILabeledInstance> testData = this.test;
            final SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
            final EClassificationPerformanceMeasure metric = EClassificationPerformanceMeasure.ERRORRATE;
            for (int i = 0; i < this.anchorPoints.length; i++) {
                /* create a subsample of the desired size, reusing the previous run where the factory supports it */
                if (this.samplingAlgorithmFactory instanceof IRerunnableSamplingAlgorithmFactory && this.samplingAlgorithm != null) {
                    ((IRerunnableSamplingAlgorithmFactory) this.samplingAlgorithmFactory).setPreviousRun(this.samplingAlgorithm);
                }
                this.samplingAlgorithm = this.samplingAlgorithmFactory.getAlgorithm(this.anchorPoints[i], this.train, this.random);
                // NOTE(review): m93call() looks like the decompiler-renamed call() of the sampling algorithm — confirm.
                final ILabeledDataset<? extends ILabeledInstance> subsample = this.samplingAlgorithm.m93call();

                /* train on the subsample and evaluate on the held-out test data */
                this.logger.debug("Running classifier with {} data points.", this.anchorPoints[i]);
                final ILearnerRunReport report = executor.execute(this.learner, subsample, testData);
                // NOTE(review): unit of the report's train start/end times is presumably milliseconds — confirm.
                this.trainingTimes[i] = (int) (report.getTrainEndTime() - report.getTrainStartTime());
                this.yValues[i] = metric.loss(report.getPredictionDiffList().getCastedView(Integer.class, ISingleLabelClassification.class));
                this.logger.debug("Training finished. Observed learning curve value (error rate) of {}.", this.yValues[i]);
            }
            if (this.logger.isInfoEnabled()) {
                this.logger.info("Computed error rates of {} for anchor points {}. Now extrapolating a curve from these observations.", Arrays.toString(this.yValues), Arrays.toString(this.anchorPoints));
            }
            return this.extrapolationMethod.extrapolateLearningCurveFromAnchorPoints(this.anchorPoints, this.yValues, this.dataset.size());
        } catch (InvalidAnchorPointsException | InterruptedException e) {
            throw e;
        } catch (AlgorithmExecutionCanceledException | TimeoutException | AlgorithmException e) {
            // Must precede the generic Exception handler; otherwise this clause is unreachable (compile error).
            throw new AlgorithmException("Error during creation of the subsamples for the anchorpoints", e);
        } catch (ExecutionException e) {
            throw new AlgorithmException("Error during learning curve extrapolation", e);
        } catch (Exception e) {
            throw new AlgorithmException("Error during training/testing the classifier", e);
        }
    }

    /**
     * Splits {@link #dataset} into stratified training and test parts and shuffles both in place.
     *
     * @param trainsplit training portion passed to the splitter
     * @param seed seed for a fresh Random used for both splitting and shuffling
     * @throws DatasetCreationException if the split fails
     * @throws InterruptedException if interrupted while splitting
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private void createSplit(final double trainsplit, final long seed) throws DatasetCreationException, InterruptedException {
        final long start = System.currentTimeMillis();
        this.logger.debug("Creating split with training portion {} and seed {}", trainsplit, seed);
        // Separate Random (not this.random) so the split does not consume the sampling random stream.
        final Random splitRandom = new Random(seed);
        try {
            final List split = new FilterBasedDatasetSplitter(new LabelBasedStratifiedSamplingFactory(), trainsplit, splitRandom).split(this.dataset);
            this.train = (ILabeledDataset) split.get(0);
            this.test = (ILabeledDataset) split.get(1);
            this.logger.debug("Shuffling train and test data");
            Collections.shuffle(this.train, splitRandom);
            Collections.shuffle(this.test, splitRandom);
            this.logger.debug("Finished split creation after {}ms", System.currentTimeMillis() - start);
        } catch (SplitFailedException e) {
            throw new DatasetCreationException(e);
        }
    }

    /** @return the learner evaluated at the anchor points */
    public ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> getLearner() {
        return this.learner;
    }

    /** @return the full dataset given to the constructor */
    public ILabeledDataset<?> getDataset() {
        return this.dataset;
    }

    /** @return the method used to extrapolate the learning curve */
    public LearningCurveExtrapolationMethod getExtrapolationMethod() {
        return this.extrapolationMethod;
    }

    /** @return the anchor points (the internal array, not a copy) */
    public int[] getAnchorPoints() {
        return this.anchorPoints;
    }

    /**
     * @return the observed error rates per anchor point (the internal array, not a copy); all zeros until
     *         {@link #extrapolateLearningCurve()} has run
     */
    public double[] getyValues() {
        return this.yValues;
    }

    /**
     * @return the training durations per anchor point (the internal array, not a copy); all zeros until
     *         {@link #extrapolateLearningCurve()} has run
     */
    public int[] getTrainingTimes() {
        return this.trainingTimes;
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    @Override
    public void setLoggerName(final String str) {
        this.logger = LoggerFactory.getLogger(str);
    }
}
