package ai.libs.jaicore.ml.intervaltree;

import ai.libs.jaicore.ml.intervaltree.aggregation.AggressiveAggregator;
import ai.libs.jaicore.ml.intervaltree.aggregation.IntervalAggregator;
import ai.libs.jaicore.ml.intervaltree.util.RQPHelper;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.geometry.euclidean.oned.Interval;
import weka.classifiers.trees.m5.M5Base;
import weka.classifiers.trees.m5.PreConstructedLinearModel;
import weka.classifiers.trees.m5.RuleNode;
import weka.core.DenseInstance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/intervaltree/ExtendedM5Tree.class */
/**
 * An M5 model tree, configured with the {@code -U} option (unpruned), that can answer range
 * queries: given one interval per attribute, it predicts an interval of target values by
 * evaluating every leaf whose region intersects the queried box and aggregating the results with
 * an {@link IntervalAggregator}.
 */
public class ExtendedM5Tree extends M5Base implements RangeQueryPredictor {
    private static final long serialVersionUID = 6099808075887732225L;

    /** Combines the values collected from all reached leaves into the final interval. */
    private final IntervalAggregator intervalAggregator;

    /** Creates an unpruned tree that aggregates leaf values with an {@link AggressiveAggregator}. */
    public ExtendedM5Tree() {
        this(new AggressiveAggregator());
    }

    /**
     * Creates an unpruned tree that aggregates leaf values with the given aggregator.
     *
     * @param intervalAggregator the aggregator used by {@link #predictInterval}
     * @throws IllegalStateException if the {@code -U} option cannot be applied
     */
    public ExtendedM5Tree(IntervalAggregator intervalAggregator) {
        try {
            setOptions(new String[]{"-U"});
            this.intervalAggregator = intervalAggregator;
        } catch (Exception e) {
            // Keep the original failure as the cause so it is not lost.
            throw new IllegalStateException("Couldn't unprune the tree", e);
        }
    }

    /**
     * Predicts the range of target values over the queried axis-aligned box.
     *
     * <p>Starting at the root, each inner node clips the current region at its split value and
     * continues into every child whose side of the split the region reaches. Each reached leaf
     * contributes two values (see {@link #predictLeaf}), and all collected values are combined by
     * the configured {@link IntervalAggregator}.
     *
     * @param intervalAndHeader the queried intervals (one per attribute) and the dataset header used
     *     to evaluate the leaf models
     * @return the aggregated prediction interval
     * @throws PredictionFailedException if a leaf model fails to classify an instance
     */
    @Override
    public Interval predictInterval(RQPHelper.IntervalAndHeader intervalAndHeader) {
        Interval[] queriedIntervals = intervalAndHeader.getIntervals();
        Instances header = intervalAndHeader.getHeaderInformation();

        // Work list of (region, subtree) pairs still to be explored.
        Deque<Map.Entry<Interval[], RuleNode>> pending = new ArrayDeque<>();
        pending.push(RQPHelper.getEntry(queriedIntervals, getM5RootNode()));

        ArrayList<Double> leafPredictions = new ArrayList<>();
        while (!pending.isEmpty()) {
            Map.Entry<Interval[], RuleNode> current = pending.pop();
            Interval[] region = current.getKey();
            RuleNode node = current.getValue();

            if (node.isLeaf()) {
                predictLeaf(leafPredictions, current, node, header);
                continue;
            }

            int splitAtt = node.splitAtt();
            double splitVal = node.splitVal();
            // Use the current (already clipped) region, not the original query, so that repeated
            // splits on the same attribute keep narrowing the box along the path.
            Interval interval = region[splitAtt];

            if (interval.getInf() > splitVal) {
                // Region lies entirely above the split: only the right child is reachable.
                pending.push(RQPHelper.getEntry(region, node.rightNode()));
            } else if (splitVal <= interval.getSup()) {
                // Split falls inside the region: clip it and explore both children.
                Interval[] leftRegion =
                        RQPHelper.substituteInterval(region, new Interval(interval.getInf(), splitVal), splitAtt);
                Interval[] rightRegion =
                        RQPHelper.substituteInterval(region, new Interval(splitVal, interval.getSup()), splitAtt);
                pending.push(RQPHelper.getEntry(leftRegion, node.leftNode()));
                pending.push(RQPHelper.getEntry(rightRegion, node.rightNode()));
            } else {
                // Region lies entirely below the split: only the left child is reachable.
                pending.push(RQPHelper.getEntry(region, node.leftNode()));
            }
        }
        return this.intervalAggregator.aggregate(leafPredictions);
    }

    /**
     * Evaluates a leaf's linear model at two opposite corners of the leaf's region and appends
     * both results to {@code predictions}: first the corner evaluation, then the opposite-corner
     * evaluation.
     *
     * <p>For each interval {@code i}, the first corner takes the bound that maximizes
     * {@code coefficients[i] * x} (the lower bound when the coefficient is negative, the upper
     * bound otherwise). The second corner takes the other bound. Bound {@code i} is written to
     * instance attribute {@code i + 1}, and attribute {@code 0} is set to {@code 1.0}.
     *
     * @param predictions receives the two model outputs
     * @param entry the (region, leaf) pair being evaluated; only its region is read
     * @param leaf the leaf whose linear model is evaluated
     * @param header dataset header attached to the constructed instances
     * @throws PredictionFailedException if the linear model fails to classify an instance
     */
    private void predictLeaf(List<Double> predictions, Map.Entry<Interval[], RuleNode> entry, RuleNode leaf,
            Instances header) {
        Interval[] region = entry.getKey();
        PreConstructedLinearModel model = leaf.getModel();
        double[] coefficients = model.coefficients();

        DenseInstance maximizingCorner = new DenseInstance(region.length + 1);
        DenseInstance minimizingCorner = new DenseInstance(region.length + 1);
        for (int i = 0; i < region.length; i++) {
            // NOTE(review): coefficients[i] decides the bound placed at attribute i + 1. If the
            // model's coefficients are indexed by header attribute index (attribute 0 first), this
            // pairing would be off by one. Confirm against the header layout.
            if (coefficients[i] < 0.0d) {
                maximizingCorner.setValue(i + 1, region[i].getInf());
                minimizingCorner.setValue(i + 1, region[i].getSup());
            } else {
                maximizingCorner.setValue(i + 1, region[i].getSup());
                minimizingCorner.setValue(i + 1, region[i].getInf());
            }
        }
        maximizingCorner.setValue(0, 1.0d);
        minimizingCorner.setValue(0, 1.0d);
        maximizingCorner.setDataset(header);
        minimizingCorner.setDataset(header);

        try {
            double first = model.classifyInstance(maximizingCorner);
            double second = model.classifyInstance(minimizingCorner);
            predictions.add(first);
            predictions.add(second);
        } catch (Exception e) {
            throw new PredictionFailedException(e);
        }
    }
}
