package hex;

import hex.Model;
import hex.SupervisedModel;
import hex.SupervisedModel.SupervisedOutput;
import hex.SupervisedModel.SupervisedParameters;
import hex.genmodel.GenModel;
import water.Key;
import water.fvec.Chunk;
import water.util.JCodeGen;
import water.util.MRUtils;
import water.util.SB;

/**
 * Base class for models trained against a response column, covering classification
 * (binomial / multinomial) and regression.
 *
 * @param <M> the concrete model type
 * @param <P> the parameter type, a {@link SupervisedParameters}
 * @param <O> the output type, a {@link SupervisedOutput}
 */
public abstract class SupervisedModel<M extends SupervisedModel<M, P, O>, P extends SupervisedParameters, O extends SupervisedOutput> extends Model<M, P, O> {

    /** Threshold returned by {@link #defaultThreshold()} when no usable binomial AUC is available. */
    private static final double FALLBACK_THRESHOLD = 0.5d;

    /**
     * Output common to supervised models. It holds the training frame's column names and
     * domains, plus class-distribution information about the response.
     */
    public static abstract class SupervisedOutput extends Model.Output {
        /**
         * Per-class counts from {@code MRUtils.ClassDist.dist()} for classifiers. For
         * non-classifiers it is a one-element array holding the training frame's row count.
         */
        public long[] _distribution;
        /**
         * Class distribution from {@code MRUtils.ClassDist.rel_dist()} for classifiers
         * (presumably relative frequencies; confirm in MRUtils). For non-classifiers it is {@code {1.0}}.
         */
        public double[] _priorClassDist;
        /**
         * Class distribution used for model building. The builder constructor initializes it to
         * the SAME array instance as {@link #_priorClassDist}, not a copy. Do not mutate either
         * array in place unless you intend both to change.
         */
        public double[] _modelClassDist;

        /** Creates an empty output. All fields are left unset. */
        public SupervisedOutput() {
            this(null);
        }

        /**
         * Initializes the output from a model builder.
         *
         * @param supervisedModelBuilder the builder whose training frame and response are
         *     inspected. If {@code null}, only the superclass constructor runs and this class's
         *     fields remain unset.
         */
        public SupervisedOutput(SupervisedModelBuilder supervisedModelBuilder) {
            super(supervisedModelBuilder);
            if (supervisedModelBuilder == null) {
                return;
            }
            this._names = supervisedModelBuilder._train.names();
            this._domains = supervisedModelBuilder._train.domains();
            if (supervisedModelBuilder.isClassifier() && supervisedModelBuilder.isSupervised()) {
                MRUtils.ClassDist classDist = new MRUtils.ClassDist(supervisedModelBuilder._nclass).doAll(supervisedModelBuilder._response);
                this._distribution = classDist.dist();
                this._priorClassDist = classDist.rel_dist();
            } else {
                // Non-classifier: a single "class" that holds every training row.
                this._distribution = new long[]{supervisedModelBuilder._train.numRows()};
                this._priorClassDist = new double[]{1.0d};
            }
            // Intentional aliasing: until a builder replaces it, the model distribution is the prior.
            this._modelClassDist = this._priorClassDist;
        }

        /**
         * Returns the number of training columns minus one. The excluded column is presumably
         * the response; confirm it is always present in {@code _names}.
         */
        @Override
        public int nfeatures() {
            return this._names.length - 1;
        }

        /** Always {@code true}: supervised models are trained against a response. */
        @Override
        public boolean isSupervised() {
            return true;
        }

        /**
         * Returns the length of {@link #_distribution}, or 1 when it has not been set.
         */
        @Override
        public int nclasses() {
            return this._distribution == null ? 1 : this._distribution.length;
        }

        /** Returns {@code true} when the model has more than one class. */
        @Override
        public boolean isClassifier() {
            return nclasses() > 1;
        }

        /**
         * Maps the class count to a category: 1 is Regression, 2 is Binomial, and any other
         * count is Multinomial.
         */
        @Override
        public ModelCategory getModelCategory() {
            int classes = nclasses();
            if (classes == 1) {
                return ModelCategory.Regression;
            }
            return classes == 2 ? ModelCategory.Binomial : ModelCategory.Multinomial;
        }
    }

    /** Parameters shared by all supervised model builders. */
    public static abstract class SupervisedParameters extends Model.Parameters {
        /** Name of the response column. */
        public String _response_column;
        /** Per-class sampling factors (presumably used when balancing classes; confirm in builders). */
        public float[] _class_sampling_factors;
        /** Whether to balance classes. When set, {@code score0} corrects predicted probabilities. */
        public boolean _balance_classes = false;
        /** Upper bound related to balanced-frame size (semantics defined by builders; default 5.0). */
        public float _max_after_balance_size = 5.0f;
        /** Maximum k for hit-ratio reporting (default 10). */
        public int _max_hit_ratio_k = 10;
        /** Maximum confusion-matrix size to report (default 20). */
        public int _max_confusion_matrix_size = 20;
    }

    public SupervisedModel(Key key, P p, O o) {
        super(key, p, o);
    }

    /**
     * Returns the decision threshold for binomial predictions.
     *
     * <p>It returns 0.5 for non-binomial models and when no training metrics exist. Otherwise
     * it prefers the validation AUC's default threshold, then the training AUC's, and falls back
     * to 0.5 when neither AUC is available.
     *
     * @return the threshold to pass to {@link GenModel#getPrediction}
     */
    public final double defaultThreshold() {
        SupervisedOutput output = (SupervisedOutput) this._output;
        if (output.nclasses() != 2 || output._training_metrics == null) {
            return FALLBACK_THRESHOLD;
        }
        // Casting null is legal, so this is safe even when no validation metrics exist.
        ModelMetricsBinomial validation = (ModelMetricsBinomial) output._validation_metrics;
        if (validation != null && validation._auc != null) {
            return validation._auc.defaultThreshold();
        }
        // Training metrics are known to be non-null here (checked above).
        ModelMetricsBinomial training = (ModelMetricsBinomial) output._training_metrics;
        return training._auc == null ? FALLBACK_THRESHOLD : training._auc.defaultThreshold();
    }

    /**
     * Scores a single row taken from a set of chunks.
     *
     * <p>Steps:
     * <ol>
     *   <li>Copies row {@code row} of the first {@code rowData.length} chunks into {@code rowData}.</li>
     *   <li>Delegates to {@code score0(double[], double[])}.</li>
     *   <li>For classifiers only: if class balancing was enabled, corrects the probabilities
     *       from the model distribution back toward the prior.</li>
     *   <li>For classifiers only: overwrites element 0 with
     *       {@link GenModel#getPrediction} at {@link #defaultThreshold()}.</li>
     * </ol>
     *
     * @param chunks input chunks; must hold at least {@code rowData.length} entries
     * @param row row index within each chunk
     * @param rowData scratch buffer receiving the row's values (overwritten)
     * @param predsBuffer buffer passed through to {@code score0(double[], double[])}
     * @return the prediction array returned by {@code score0(double[], double[])}, post-processed
     *     for classifiers
     */
    @Override
    public double[] score0(Chunk[] chunks, int row, double[] rowData, double[] predsBuffer) {
        for (int col = 0; col < rowData.length; col++) {
            rowData[col] = chunks[col].atd(row);
        }
        double[] preds = score0(rowData, predsBuffer);
        SupervisedOutput output = (SupervisedOutput) this._output;
        if (output.isClassifier()) {
            if (((SupervisedParameters) this._parms)._balance_classes) {
                GenModel.correctProbabilities(preds, output._priorClassDist, output._modelClassDist);
            }
            preds[0] = GenModel.getPrediction(preds, rowData, defaultThreshold());
        }
        return preds;
    }

    /**
     * Emits the prior and model class distributions as static variables into generated Java
     * scoring code.
     *
     * @param sb the source builder to append to
     * @return {@code sb}, for chaining
     */
    @Override
    protected SB toJavaPROB(SB sb) {
        SupervisedOutput output = (SupervisedOutput) this._output;
        JCodeGen.toStaticVar(sb, "PRIOR_CLASS_DISTRIB", output._priorClassDist, "Prior class distribution");
        JCodeGen.toStaticVar(sb, "MODEL_CLASS_DISTRIB", output._modelClassDist, "Class distribution used for model building");
        return sb;
    }
}
