package ai.libs.jaicore.ml.intervaltree;

import ai.libs.jaicore.ml.core.FeatureSpace;
import ai.libs.jaicore.ml.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.intervaltree.aggregation.QuantileAggregator;
import ai.libs.jaicore.ml.intervaltree.util.RQPHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/intervaltree/ExtendedRandomForest.class */
public class ExtendedRandomForest extends RandomForest implements RangeQueryPredictor {
    private static final long serialVersionUID = 8774800172762290733L;
    private static final Logger log = LoggerFactory.getLogger(ExtendedRandomForest.class);
    private final IntervalAggregator forestAggregator;
    private FeatureSpace featureSpace;

    public ExtendedRandomForest() {
        this(new QuantileAggregator(0.15d), new AggressiveAggregator());
    }

    public ExtendedRandomForest(IntervalAggregator intervalAggregator, IntervalAggregator intervalAggregator2) {
        setClassifier(new ExtendedRandomTree(intervalAggregator));
        this.forestAggregator = intervalAggregator2;
    }

    public ExtendedRandomForest(FeatureSpace featureSpace) {
        this();
        this.featureSpace = featureSpace;
        getClassifier().setFeatureSpace(featureSpace);
    }

    public ExtendedRandomForest(IntervalAggregator intervalAggregator, IntervalAggregator intervalAggregator2, FeatureSpace featureSpace) {
        this.forestAggregator = intervalAggregator2;
        this.featureSpace = featureSpace;
        ExtendedRandomTree extendedRandomTree = new ExtendedRandomTree(intervalAggregator);
        extendedRandomTree.setFeatureSpace(featureSpace);
        setClassifier(extendedRandomTree);
    }

    public void prepareForest(Instances instances) {
        this.featureSpace = new FeatureSpace(instances);
        for (ExtendedRandomTree extendedRandomTree : this.m_Classifiers) {
            extendedRandomTree.setFeatureSpace(this.featureSpace);
            extendedRandomTree.preprocess();
        }
    }

    public void printVariances() {
        for (ExtendedRandomTree extendedRandomTree : this.m_Classifiers) {
            log.debug("cur var: {}", Double.valueOf(extendedRandomTree.getTotalVariance()));
        }
    }

    public double computeMarginalVarianceContributionForFeatureSubset(Set<Integer> set) {
        double d = 0.0d;
        for (ExtendedRandomTree extendedRandomTree : this.m_Classifiers) {
            d += (extendedRandomTree.computeMarginalVarianceContributionForSubsetOfFeatures(set) * 1.0d) / this.m_Classifiers.length;
        }
        return d;
    }

    public double computeMarginalVarianceContributionForFeatureSubsetNotNormalized(Set<Integer> set) {
        double d = 0.0d;
        for (ExtendedRandomTree extendedRandomTree : this.m_Classifiers) {
            d += (extendedRandomTree.computeMarginalVarianceContributionForSubsetOfFeaturesNotNormalized(set) * 1.0d) / this.m_Classifiers.length;
        }
        return d;
    }

    public int getSize() {
        return this.m_Classifiers.length;
    }

    public FeatureSpace getFeatureSpace() {
        return this.featureSpace;
    }

    protected String defaultClassifierString() {
        return "jaicore.ml.intervaltree.ExtendedRandomTree";
    }

    public ExtendedRandomForest(int i) {
        this();
        setSeed(i);
    }

    @Override // ai.libs.jaicore.ml.intervaltree.RangeQueryPredictor
    public Interval predictInterval(Instance instance) {
        ArrayList arrayList = new ArrayList(this.m_Classifiers.length * 2);
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            Interval predictInterval = this.m_Classifiers[i].predictInterval(instance);
            arrayList.add(Double.valueOf(predictInterval.getInf()));
            arrayList.add(Double.valueOf(predictInterval.getSup()));
        }
        return this.forestAggregator.aggregate(arrayList);
    }

    @Override // ai.libs.jaicore.ml.intervaltree.RangeQueryPredictor
    public Interval predictInterval(RQPHelper.IntervalAndHeader intervalAndHeader) {
        ArrayList arrayList = new ArrayList(this.m_Classifiers.length * 2);
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            Interval predictInterval = this.m_Classifiers[i].predictInterval(intervalAndHeader);
            arrayList.add(Double.valueOf(predictInterval.getInf()));
            arrayList.add(Double.valueOf(predictInterval.getSup()));
        }
        return this.forestAggregator.aggregate(arrayList);
    }
}
