package hex;

import hex.Model;
import java.util.Arrays;
import java.util.Comparator;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.Value;
import water.fvec.Frame;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/ModelMetrics.class */
/**
 * Base class for scoring metrics of a model on a frame.
 *
 * <p>Each instance is stored in the DKV under a key derived from the model key/checksum and the
 * frame key/checksum (see {@link #buildKey(Model, Frame)}). The checksums are captured at
 * construction time, so {@link #isForModel(Model)} and {@link #isForFrame(Frame)} can later tell
 * whether a model or frame still matches the one that was scored.
 */
public class ModelMetrics extends Keyed<ModelMetrics> {
    /** Free-form description supplied to the constructor. */
    public String _description;
    /** Key of the scored model. */
    final Key _modelKey;
    /** Key of the frame the model was scored on. */
    final Key _frameKey;
    /** Category reported by {@code model._output.getModelCategory()} at construction time. */
    final Model.ModelCategory _model_category;
    /** {@code model.checksum()} captured at construction time. */
    final long _model_checksum;
    /** {@code frame.checksum()} captured at construction time. */
    final long _frame_checksum;
    /** Cached model; transient, re-fetched from the DKV by {@link #model()} when absent. */
    transient Model _model;
    /** Cached frame; transient, re-fetched from the DKV by {@link #frame()} when absent. */
    transient Frame _frame;
    /** Wall-clock time (milliseconds since the epoch) at which this object was constructed. */
    public final long _scoring_time;
    /** Error value passed to the constructor (presumably mean squared error — confirm with callers). */
    public final double _MSE;

    /**
     * Accumulates scoring statistics row by row and turns them into a {@link ModelMetrics}.
     *
     * @param <T> the concrete builder type, so {@link #reduce} merges builders of the same kind
     */
    public static abstract class MetricBuilder<T extends MetricBuilder<T>> extends Iced {
        /** Transient scratch buffer; its use is up to subclasses. */
        public transient double[] _work;
        /** Running sum of squared errors (presumably filled in by subclasses' {@link #perRow}). */
        public double _sumsqe;
        /** Running row count (presumably filled in by subclasses' {@link #perRow}). */
        public long _count;

        /**
         * Processes one row of predictions.
         *
         * @param dArr  per-row prediction values (layout defined by the subclass)
         * @param fArr  per-row actual values (layout defined by the subclass)
         * @param model the model being scored
         * @return a subclass-defined array
         */
        public abstract double[] perRow(double[] dArr, float[] fArr, Model model);

        /**
         * Merges another builder's squared-error sum and row count into this one.
         *
         * @param t the builder to fold in; it is not modified
         */
        public void reduce(T t) {
            this._sumsqe += t._sumsqe;
            this._count += t._count;
        }

        /** Hook run after all reductions; a no-op in the base class. */
        public void postGlobal() {
        }

        /**
         * Builds the final metrics object.
         *
         * @param model the scored model
         * @param frame the scored frame
         * @param d     a subclass-defined value (presumably sigma or similar — confirm)
         * @return the metrics object
         */
        public abstract ModelMetrics makeModelMetrics(Model model, Frame frame, double d);
    }

    /**
     * Creates metrics for {@code model} scored on {@code frame} and publishes them to the DKV.
     *
     * <p>NOTE(review): {@code DKV.put(this)} runs inside this (super) constructor, i.e. before any
     * subclass constructor body or subclass field initializer has run, so a subclass instance is
     * published while partially initialized. Kept as-is because callers may rely on the implicit put.
     *
     * @param model the scored model; must be non-null (its key and checksum are read)
     * @param frame the scoring frame; must be non-null (its key and checksum are read)
     * @param d     the MSE value to record
     * @param str   free-form description
     */
    public ModelMetrics(Model model, Frame frame, double d, String str) {
        super(buildKey(model, frame));
        this._description = str;
        this._modelKey = model._key;
        this._frameKey = frame._key;
        this._model_category = model._output.getModelCategory();
        this._model = model;
        this._frame = frame;
        this._model_checksum = model.checksum();
        this._frame_checksum = frame.checksum();
        this._MSE = d;
        this._scoring_time = System.currentTimeMillis();
        DKV.put(this);
    }

    /**
     * Returns the scored model, fetching it from the DKV and caching it when not already cached.
     *
     * @return the model, or whatever {@code DKV.getGet} returns for {@link #_modelKey}
     */
    public Model model() {
        if (this._model == null) {
            this._model = (Model) DKV.getGet(this._modelKey);
        }
        return this._model;
    }

    /**
     * Returns the scored frame, fetching it from the DKV and caching it when not already cached.
     *
     * @return the frame, or whatever {@code DKV.getGet} returns for {@link #_frameKey}
     */
    public Frame frame() {
        if (this._frame == null) {
            this._frame = (Frame) DKV.getGet(this._frameKey);
        }
        return this._frame;
    }

    /** @return the MSE value recorded at construction */
    public double mse() {
        return this._MSE;
    }

    /** @return the confusion matrix; {@code null} in the base class */
    public ConfusionMatrix cm() {
        return null;
    }

    /** @return the hit ratios; {@code null} in the base class */
    public float[] hr() {
        return null;
    }

    /** @return the AUC object; {@code null} in the base class */
    public AUC2 auc() {
        return null;
    }

    /**
     * Builds a variable-importance table from a {@link VarImp}.
     *
     * @param varImp importances and names; may be {@code null}
     * @return the table, or {@code null} if {@code varImp} is {@code null} or has no entries
     */
    public static TwoDimTable calcVarImp(VarImp varImp) {
        if (varImp == null) {
            return null;
        }
        double[] values = new double[varImp._varimp.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = varImp._varimp[i];
        }
        return calcVarImp(values, varImp._names);
    }

    /**
     * Builds a variable-importance table from single-precision importances.
     *
     * @param fArr   raw importances; may be {@code null}
     * @param strArr variable names aligned with {@code fArr}, or {@code null} for "C1".."Cn"
     * @return the table, or {@code null} if {@code fArr} is {@code null} or empty
     */
    public static TwoDimTable calcVarImp(float[] fArr, String[] strArr) {
        // Match the double[] overload's contract: null input yields null instead of an NPE.
        if (fArr == null) {
            return null;
        }
        double[] values = new double[fArr.length];
        for (int i = 0; i < values.length; i++) {
            values[i] = fArr[i];
        }
        return calcVarImp(values, strArr);
    }

    /**
     * Builds a variable-importance table with the default header and column names.
     *
     * @param dArr   raw importances; may be {@code null}
     * @param strArr variable names aligned with {@code dArr}, or {@code null} for "C1".."Cn"
     * @return the table, or {@code null} if {@code dArr} is {@code null} or empty
     */
    public static TwoDimTable calcVarImp(double[] dArr, String[] strArr) {
        return calcVarImp(dArr, strArr, "Variable Importances", new String[]{"Relative Importance", "Scaled Importance", "Percentage"});
    }

    /**
     * Builds a variable-importance table whose rows are sorted by descending importance.
     *
     * <p>The three columns per row are:
     * <ol>
     *   <li>the raw importance,</li>
     *   <li>the importance divided by the largest importance,</li>
     *   <li>the importance divided by the sum of all importances (a fraction, not multiplied by 100).</li>
     * </ol>
     *
     * @param dArr    raw importances; may be {@code null}
     * @param strArr  variable names aligned with {@code dArr}, or {@code null} to use "C1".."Cn"
     * @param str     table header
     * @param strArr2 the three column headers
     * @return the table, or {@code null} if {@code dArr} is {@code null} or empty
     * @throws IllegalArgumentException if {@code strArr} has fewer entries than {@code dArr}
     */
    public static TwoDimTable calcVarImp(final double[] dArr, String[] strArr, String str, String[] strArr2) {
        // An empty array has no maximum to scale by, so treat it like missing input.
        if (dArr == null || dArr.length == 0) {
            return null;
        }
        final int n = dArr.length;
        String[] names = strArr;
        if (names == null) {
            names = new String[n];
            for (int i = 0; i < n; i++) {
                names[i] = "C" + (i + 1);
            }
        } else if (names.length < n) {
            throw new IllegalArgumentException(
                    "Expected at least " + n + " variable names but got " + names.length);
        }

        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Descending by importance. Negating both operands (instead of swapping them) keeps NaN
        // entries at the end, because Double.compare ranks NaN above every other value.
        // Arrays.sort on objects is stable, so ties keep their original index order.
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer a, Integer b) {
                return Double.compare(-dArr[a], -dArr[b]);
            }
        });

        final double max = dArr[order[0]];
        String[] sortedNames = new String[n];
        double[][] table = new double[n][3];
        double total = 0.0d;
        // First pass: rank names and values, scale by the maximum, and sum in ranked order.
        for (int row = 0; row < n; row++) {
            int idx = order[row];
            total += dArr[idx];
            sortedNames[row] = names[idx];
            table[row][0] = dArr[idx];
            table[row][1] = dArr[idx] / max;
        }
        // Second pass: the share of the total needs the complete sum from the first pass.
        for (int row = 0; row < n; row++) {
            table[row][2] = dArr[order[row]] / total;
        }

        String[] colTypes = new String[3];
        String[] colFormats = new String[3];
        Arrays.fill(colTypes, "double");
        Arrays.fill(colFormats, "%5f");
        return new TwoDimTable(str, null, sortedNames, strArr2, colTypes, colFormats, "Variable", new String[n], table);
    }

    /** Formats the DKV key "modelmetrics_{model}@{modelChecksum}_on_{frame}@{frameChecksum}". */
    private static Key<ModelMetrics> buildKey(Key key, long j, Key key2, long j2) {
        return Key.make("modelmetrics_" + key + "@" + j + "_on_" + key2 + "@" + j2);
    }

    /**
     * Returns the DKV key under which metrics for {@code model} on {@code frame} are stored.
     *
     * @param model the model; must be non-null when {@code frame} is non-null
     * @param frame the frame; may be {@code null}
     * @return the key, or {@code null} when {@code frame} is {@code null}
     */
    public static Key<ModelMetrics> buildKey(Model model, Frame frame) {
        if (frame == null) {
            return null;
        }
        return buildKey(model._key, model.checksum(), frame._key, frame.checksum());
    }

    /** @return true if {@code model}'s current checksum equals the one captured at construction */
    public boolean isForModel(Model model) {
        return this._model_checksum == model.checksum();
    }

    /** @return true if {@code frame}'s current checksum equals the one captured at construction */
    public boolean isForFrame(Frame frame) {
        return this._frame_checksum == frame.checksum();
    }

    /**
     * Looks up previously stored metrics for {@code model} on {@code frame}.
     *
     * @param model the model
     * @param frame the frame; may be {@code null}
     * @return the stored metrics, or {@code null} if {@code frame} is {@code null} or nothing is stored
     */
    public static ModelMetrics getFromDKV(Model model, Frame frame) {
        Key<ModelMetrics> key = buildKey(model, frame);
        // buildKey yields null for a null frame; don't hand a null key to the DKV.
        if (key == null) {
            return null;
        }
        Value value = DKV.get(key);
        return value == null ? null : (ModelMetrics) value.get();
    }

    /** Combines the captured frame and model checksums. */
    @Override // water.Keyed
    protected long checksum_impl() {
        return (this._frame_checksum * 13) + (this._model_checksum * 17);
    }
}
