package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.StackedEnsembleModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.DKV;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;

/**
 * Model builder for stacked ensembles.
 *
 * <p>A "level one" frame is assembled from the cross-validation holdout predictions of every base
 * model listed in {@code _base_models}, plus the response column taken from the ensemble's common
 * training frame. A non-negative GLM metalearner is then trained on that frame. The GLM family is
 * gaussian when the ensemble's model category is regression and binomial otherwise. Only
 * regression and binomial problems are buildable, and the builder is advertised as experimental.
 */
public class StackedEnsemble extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput> {

    /** Driver created by the most recent {@link #trainModelImpl()} call. */
    StackedEnsembleDriver _driver;

    /** The ensemble model under construction; created in {@link StackedEnsembleDriver#computeImpl()}. */
    protected StackedEnsembleModel _model;

    /**
     * Runs one stacking job. It validates the parameters, builds the level-one frame, trains the
     * GLM metalearner and scores the resulting ensemble model.
     */
    public class StackedEnsembleDriver extends ModelBuilder<StackedEnsembleModel, StackedEnsembleModel.StackedEnsembleParameters, StackedEnsembleModel.StackedEnsembleOutput>.Driver {

        private StackedEnsembleDriver() {
            // The enclosing StackedEnsemble instance is supplied implicitly as the outer instance of Driver.
            super();
        }

        /**
         * Builds the level-one training frame for the metalearner and registers it in the DKV.
         *
         * <p>Missing base models are skipped with a warning. Base models without cross-validation
         * holdout predictions are rejected.
         *
         * @param parms ensemble parameters whose {@code _base_models} are stacked
         * @return the level-one frame: one prediction column per base model, named after the model
         *     key, followed by the response column
         * @throws H2OIllegalArgumentException if a base model has no CV holdout predictions frame,
         *     or if the model type cannot be stacked
         */
        private Frame prepareLevelOneFrame(StackedEnsembleModel.StackedEnsembleParameters parms) {
            final StackedEnsemble builder = StackedEnsemble.this;
            Frame levelOneFrame = new Frame(Key.<Frame>make("levelone_" + builder._model._key.toString()));

            for (Key<Model> baseModelKey : parms._base_models) {
                Model baseModel = DKV.getGet(baseModelKey);
                if (baseModel == null) {
                    Log.warn("Failed to find base model; skipping: " + baseModelKey);
                    continue;
                }
                if (baseModel._output._cross_validation_holdout_predictions_frame_id == null) {
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame id. . .  Looks like keep_cross_validation_predictions wasn't set when building the models.");
                }
                Frame holdoutPreds = baseModel._output._cross_validation_holdout_predictions_frame_id.get();
                if (holdoutPreds == null) {
                    // The key is set but the frame itself is gone; fail clearly instead of NPE-ing below.
                    throw new H2OIllegalArgumentException("Failed to find the xval predictions frame for base model " + baseModelKey + "; it may have been deleted.");
                }
                StackedEnsemble.addModelPredictionsToLevelOneFrame(baseModel, holdoutPreds, levelOneFrame);
            }
            levelOneFrame.add(builder._model.responseColumn, builder._model.commonTrainingFrame.vec(builder._model.responseColumn));

            // A frame from an earlier run may already live under the same key. Its columns are
            // detached and the emptied frame is written back before the new frame is
            // delete_and_lock'ed. NOTE(review): presumably this stops deletion of the stale frame
            // from also deleting Vecs it shares with base-model prediction frames — confirm.
            Frame stale = DKV.getGet(levelOneFrame._key);
            if (stale != null) {
                stale.removeAll();
                stale.write_lock(builder._job);
                stale.update(builder._job);
                stale.unlock(builder._job);
            }
            levelOneFrame.delete_and_lock(builder._job);
            levelOneFrame.unlock(builder._job);
            Log.info("Finished creating \"level one\" frame for stacking: " + levelOneFrame.toString());
            return levelOneFrame;
        }

        /**
         * Configures and trains the non-negative GLM metalearner on the level-one frame, and
         * blocks until its job has stopped running.
         *
         * <p>While waiting, the metalearner job's {@code _work} is reported to this job's progress
         * every 100 ms. An interrupt received during the wait does not abort it, because the
         * metalearner is required. The interrupt is re-asserted on the current thread once
         * waiting ends.
         *
         * @param levelOneFrame frame produced by {@link #prepareLevelOneFrame}
         * @return the GLM builder whose training job has finished
         */
        private GLM trainMetalearner(Frame levelOneFrame) {
            final StackedEnsemble builder = StackedEnsemble.this;
            Key<Model> metalearnerKey = Key.<Model>make("metalearner_" + builder._model._key);
            Job<Model> metalearnerBuildJob = new Job<Model>(metalearnerKey, ModelBuilder.javaName("glm"), "StackingEnsemble metalearner (GLM)");
            GLM metalearner = (GLM) ModelBuilder.make("GLM", metalearnerBuildJob, metalearnerKey);

            GLMModel.GLMParameters glmParms = metalearner._parms;
            glmParms._non_negative = true;
            glmParms._train = levelOneFrame._key;
            glmParms._response_column = builder._model.responseColumn;
            glmParms._family = builder._model.modelCategory == ModelCategory.Regression
                    ? GLMModel.GLMParameters.Family.gaussian
                    : GLMModel.GLMParameters.Family.binomial;
            metalearner.init(false);

            Job<GLMModel> trainingJob = metalearner.trainModel();
            boolean interrupted = false;
            try {
                while (trainingJob.isRunning()) {
                    builder._job.update(trainingJob._work, "training metalearner");
                    try {
                        Thread.sleep(100L);
                    } catch (InterruptedException e) {
                        // Keep waiting: the metalearner result is needed. Restore the flag afterwards.
                        interrupted = true;
                    }
                }
            } finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
            Log.info("Finished training metalearner model.");
            return metalearner;
        }

        /**
         * Validates parameters, creates and write-locks the ensemble model, trains the
         * metalearner, computes metrics and writes the final model state.
         *
         * <p>The model's write lock is released even when any of these steps throws.
         *
         * @throws H2OModelBuilderIllegalArgumentException if {@code init(true)} reported errors
         * @throws H2OIllegalArgumentException if a base model cannot be stacked
         */
        @Override
        public void computeImpl() {
            final StackedEnsemble builder = StackedEnsemble.this;
            builder.init(true);
            if (builder.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
            }

            builder._model = new StackedEnsembleModel(builder.dest(), builder._parms, new StackedEnsembleModel.StackedEnsembleOutput(builder));
            builder._model.delete_and_lock(builder._job);
            try {
                builder._model.checkAndInheritModelProperties();
                Frame levelOneFrame = prepareLevelOneFrame(builder._parms);
                GLM metalearner = trainMetalearner(levelOneFrame);
                builder._model._output._metalearner = metalearner.get();
                builder._model.doScoreMetrics(builder._job);
                builder._model.update(builder._job);
            } finally {
                // Never leave the model key write-locked, including on failure.
                builder._model.unlock(builder._job);
            }
        }
    }

    /**
     * Creates a stacked-ensemble builder with default parameters.
     *
     * @param startup_once forwarded unchanged to the {@code ModelBuilder(parms, boolean)}
     *     constructor. NOTE(review): presumably the one-time builder-registration flag — confirm.
     */
    public StackedEnsemble(boolean startup_once) {
        super(new StackedEnsembleModel.StackedEnsembleParameters(), startup_once);
    }

    /** @return the model categories this builder supports: regression and binomial */
    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial};
    }

    /** @return {@link ModelBuilder.BuilderVisibility#Experimental} */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * Creates the driver that performs the stacking work, and remembers it in {@link #_driver}.
     *
     * @return the new driver
     */
    @Override
    protected StackedEnsembleDriver trainModelImpl() {
        StackedEnsembleDriver driver = new StackedEnsembleDriver();
        this._driver = driver;
        return driver;
    }

    /**
     * Legacy alias kept for source compatibility with the decompiled form of this class.
     *
     * @return the result of {@link #trainModelImpl()}
     * @deprecated use {@link #trainModelImpl()}
     */
    @Deprecated
    public StackedEnsembleDriver m61trainModelImpl() {
        return trainModelImpl();
    }

    /**
     * Appends one base model's prediction column to the level-one frame, named after the model key.
     *
     * <p>Binomial classifiers contribute column index 2 of {@code preds}. NOTE(review): this
     * assumes the usual (predict, p0, p1) layout, making it the positive-class probability —
     * confirm. Regression (supervised, non-classifier) models contribute the {@code "predict"}
     * column.
     *
     * @param model base model that produced {@code preds}
     * @param preds the model's prediction frame (here, its CV holdout predictions)
     * @param levelOneFrame frame receiving the new column
     * @throws H2OIllegalArgumentException for multinomial classifiers, autoencoders and
     *     unsupervised models, none of which can be stacked yet
     */
    public static void addModelPredictionsToLevelOneFrame(Model model, Frame preds, Frame levelOneFrame) {
        if (model._output.isBinomialClassifier()) {
            levelOneFrame.add(model._key.toString(), preds.vec(2));
            return;
        }
        if (model._output.isClassifier()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack multinomial classifiers: " + model._key);
        }
        if (model._output.isAutoencoder()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack autoencoders: " + model._key);
        }
        if (!model._output.isSupervised()) {
            throw new H2OIllegalArgumentException("Don't yet know how to stack unsupervised models: " + model._key);
        }
        levelOneFrame.add(model._key.toString(), preds.vec("predict"));
    }
}
