package hex.ensemble;

import hex.Model;
import hex.Model.Parameters;
import hex.ModelBuilder;
import hex.ensemble.StackedEnsembleModel;
import water.DKV;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/ensemble/Metalearner.class */
/**
 * Base class for the component that trains the metalearner of a {@link StackedEnsembleModel}.
 *
 * <p>{@link #init} stores the level-one frames, parameters and job handles. {@link #compute()}
 * then does the following:
 * <ul>
 *   <li>creates a builder via {@link #createBuilder()};</li>
 *   <li>configures it: common settings, cross-validation settings, then the
 *       {@link #setCustomParams}/{@link #validateParams} hooks;</li>
 *   <li>trains it and attaches the resulting model to the ensemble's output;</li>
 *   <li>cleans up intermediate data.</li>
 * </ul>
 *
 * @param <B> builder used to train the metalearner
 * @param <M> metalearner model type produced by {@code B}
 * @param <P> metalearner parameters type consumed by {@code B}
 */
public abstract class Metalearner<B extends ModelBuilder<M, P, ?>, M extends Model<M, P, ?>, P extends Model.Parameters> {

    /** Interval between progress updates while waiting for the metalearner training job. */
    private static final long POLL_INTERVAL_MS = 100L;

    /** Level-one frame the metalearner is trained on; becomes the metalearner's {@code _train}. */
    protected Frame _levelOneTrainingFrame;
    /** Optional level-one validation frame; becomes {@code _valid}. May be {@code null}. */
    protected Frame _levelOneValidationFrame;
    /** Ensemble model being built; write-locked for the duration of {@link #compute()}. */
    protected StackedEnsembleModel _model;
    /** Ensemble parameters; supplies the {@code _keep_*} flags checked by {@link #compute()} and {@link #cleanup()}. */
    protected StackedEnsembleModel.StackedEnsembleParameters _parms;
    /** Ensemble job; receives progress updates while the metalearner trains. */
    protected Job _job;
    /** Metalearner model key; stored by {@link #init} but not read in this class (presumably used by subclasses). */
    protected Key<Model> _metalearnerKey;
    /** Metalearner job handle; stored by {@link #init} but not read in this class (presumably used by subclasses). */
    protected Job _metalearnerJob;
    /** User-supplied metalearner parameters; used only when {@link #_hasMetalearnerParams} is true. */
    protected P _metalearner_parameters;
    /** Whether {@link #_metalearner_parameters} replaces the builder's default parameters. */
    protected boolean _hasMetalearnerParams;
    /** Seed given to the metalearner when its own {@code _seed} is {@code -1}. */
    protected long _metalearnerSeed;
    /** Value copied into the metalearner's {@code _max_runtime_secs}. */
    protected long _maxRuntimeSecs;

    /* loaded from: input_file:hex/ensemble/Metalearner$Algorithm.class */
    /** Algorithm identifiers a metalearner can be built with. */
    public enum Algorithm {
        AUTO,
        deeplearning,
        drf,
        gbm,
        glm,
        naivebayes,
        xgboost
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Stores everything {@link #compute()} needs. Performs no validation or work of its own.
     *
     * @param levelOneTrainingFrame   frame the metalearner trains on
     * @param levelOneValidationFrame optional validation frame, may be {@code null}
     * @param metalearnerParameters   user-supplied metalearner parameters
     * @param model                   the ensemble model receiving the trained metalearner
     * @param job                     the ensemble job used for locking and progress updates
     * @param metalearnerKey          metalearner model key (stored only)
     * @param metalearnerJob          metalearner job handle (stored only)
     * @param parms                   ensemble parameters providing the {@code _keep_*} flags
     * @param hasMetalearnerParams    whether {@code metalearnerParameters} should be applied
     * @param metalearnerSeed         seed used when the metalearner's seed is {@code -1}
     * @param maxRuntimeSecs          runtime budget passed to the metalearner
     */
    public void init(Frame levelOneTrainingFrame, Frame levelOneValidationFrame, P metalearnerParameters,
                     StackedEnsembleModel model, Job job, Key<Model> metalearnerKey, Job metalearnerJob,
                     StackedEnsembleModel.StackedEnsembleParameters parms, boolean hasMetalearnerParams,
                     long metalearnerSeed, long maxRuntimeSecs) {
        this._levelOneTrainingFrame = levelOneTrainingFrame;
        this._levelOneValidationFrame = levelOneValidationFrame;
        this._metalearner_parameters = metalearnerParameters;
        this._model = model;
        this._job = job;
        this._metalearnerKey = metalearnerKey;
        this._metalearnerJob = metalearnerJob;
        this._parms = parms;
        this._hasMetalearnerParams = hasMetalearnerParams;
        this._metalearnerSeed = metalearnerSeed;
        this._maxRuntimeSecs = maxRuntimeSecs;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Trains the metalearner and attaches it to the ensemble model.
     *
     * <p>The ensemble model is write-locked for the whole run. Whatever the outcome, the method
     * always runs {@link #cleanup()}, then updates the model, then unlocks it. The unlock happens
     * even if cleanup or the update throws.
     *
     * <p>If the thread is interrupted while polling the training job, polling continues until the
     * job stops running. The interrupt status is restored once the model has been released.
     */
    public void compute() {
        boolean interrupted = false;
        try {
            _model.write_lock(_job);
            B builder = createBuilder();
            if (_hasMetalearnerParams) {
                builder._parms = _metalearner_parameters;
            }
            setCommonParams(builder._parms);
            setCrossValidationParams(builder._parms);
            setCustomParams(builder._parms);
            validateParams(builder._parms);
            builder.init(false);

            interrupted = awaitTraining(builder.trainModel());
            Log.info("Finished training metalearner model(" + ensembleParms()._metalearner_algorithm + ").");

            ensembleOutput()._metalearner = builder.get();
            _model._dist = ensembleOutput()._metalearner._dist;
            _model.doScoreOrCopyMetrics(_job);
            if (_parms._keep_levelone_frame) {
                ensembleOutput()._levelone_frame_id = _levelOneTrainingFrame;
            }
        } finally {
            try {
                releaseModel();
            } finally {
                if (interrupted) {
                    // Re-assert the interrupt that was deferred while waiting for training.
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /**
     * Polls {@code trainJob} until it stops running. On each iteration it passes the job's work
     * value and a status message to {@link #_job}.
     *
     * @param trainJob the metalearner training job
     * @return {@code true} if the thread was interrupted while sleeping between polls
     */
    private boolean awaitTraining(Job trainJob) {
        boolean interrupted = false;
        while (trainJob.isRunning()) {
            try {
                _job.update(trainJob.getWork(), "training metalearner(" + ensembleParms()._metalearner_algorithm + ")");
                Thread.sleep(POLL_INTERVAL_MS);
            } catch (InterruptedException e) {
                // Keep waiting: the builder's result is read right after this loop. Record the
                // interrupt so the caller can restore the flag instead of losing it.
                interrupted = true;
            }
        }
        return interrupted;
    }

    /** Runs {@link #cleanup()}, then updates and unlocks the model; the unlock always runs. */
    private void releaseModel() {
        try {
            cleanup();
        } finally {
            try {
                _model.update(_job);
            } finally {
                _model.unlock(_job);
            }
        }
    }

    /** Returns the ensemble parameters stored on {@link #_model}. */
    private StackedEnsembleModel.StackedEnsembleParameters ensembleParms() {
        return (StackedEnsembleModel.StackedEnsembleParameters) _model._parms;
    }

    /** Returns the ensemble output stored on {@link #_model}. */
    private StackedEnsembleModel.StackedEnsembleOutput ensembleOutput() {
        return (StackedEnsembleModel.StackedEnsembleOutput) _model._output;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the builder used to train the metalearner.
     *
     * @return a fresh builder; its {@code _parms} may be replaced by {@link #compute()}
     */
    public abstract B createBuilder();

    /**
     * Applies the settings every metalearner shares:
     * <ul>
     *   <li>the seed, only when the seed is {@code -1};</li>
     *   <li>the training and validation frames and the response column;</li>
     *   <li>the runtime budget;</li>
     *   <li>the weights/offset columns and time-budget factor copied from the ensemble.</li>
     * </ul>
     *
     * @param p metalearner parameters to mutate
     */
    protected void setCommonParams(P p) {
        StackedEnsembleModel.StackedEnsembleParameters seParms = ensembleParms();
        if (p._seed == -1) {
            p._seed = _metalearnerSeed;
        }
        p._train = _levelOneTrainingFrame._key;
        p._valid = _levelOneValidationFrame == null ? null : _levelOneValidationFrame._key;
        p._response_column = _model.responseColumn;
        p._max_runtime_secs = _maxRuntimeSecs;
        p._weights_column = seParms._weights_column;
        p._offset_column = seParms._offset_column;
        p._main_model_time_budget_factor = seParms._main_model_time_budget_factor;
    }

    /**
     * Copies the ensemble's metalearner cross-validation settings into {@code p}.
     *
     * <p>If a metalearner fold column is set, only {@code _fold_column} is assigned. Otherwise
     * {@code _nfolds} is copied. When nfolds exceeds 1, {@code _fold_assignment} is set to the
     * configured scheme, or to {@code AUTO} if none is configured.
     *
     * @param p metalearner parameters to mutate
     */
    protected void setCrossValidationParams(P p) {
        StackedEnsembleModel.StackedEnsembleParameters seParms = ensembleParms();
        if (seParms._metalearner_fold_column != null) {
            p._fold_column = seParms._metalearner_fold_column;
            return;
        }
        p._nfolds = seParms._metalearner_nfolds;
        if (seParms._metalearner_nfolds > 1) {
            p._fold_assignment = seParms._metalearner_fold_assignment == null
                    ? Model.Parameters.FoldAssignmentScheme.AUTO
                    : seParms._metalearner_fold_assignment;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Hook for algorithm-specific parameter tweaks. It runs after the common and cross-validation
     * settings have been applied. Does nothing by default.
     *
     * @param p metalearner parameters to mutate
     */
    public void setCustomParams(P p) {
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Hook for validating the final metalearner parameters before the builder is initialised.
     * Does nothing by default.
     *
     * @param p metalearner parameters to check
     */
    public void validateParams(P p) {
    }

    /**
     * Removes intermediate data once training is over (successfully or not):
     * <ul>
     *   <li>deletes the base-model predictions unless {@code _keep_base_model_predictions} is set;</li>
     *   <li>removes the level-one training frame from the DKV unless {@code _keep_levelone_frame}
     *       is set;</li>
     *   <li>always removes the level-one validation frame, if there is one.</li>
     * </ul>
     */
    protected void cleanup() {
        if (!_parms._keep_base_model_predictions) {
            _model.deleteBaseModelPredictions();
        }
        if (!_parms._keep_levelone_frame) {
            DKV.remove(_levelOneTrainingFrame._key);
        }
        if (_levelOneValidationFrame != null) {
            DKV.remove(_levelOneValidationFrame._key);
        }
    }
}
