package hex.ensemble;

import hex.Model;
import hex.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Smoke tests for {@link StackedEnsemble}: for several datasets, trains a GBM and a DRF base model
 * (5-fold Modulo CV with kept CV predictions), stacks them with the requested metalearner, and
 * cleans everything up again.
 */
public class StackedEnsembleTest extends TestUtil {
    /** Airlines columns dropped before training. */
    static final String[] ignored_aircols = {"DepTime", "ArrTime", "AirTime", "ArrDelay", "DepDelay", "TaxiIn", "TaxiOut", "Cancelled", "CancellationCode", "Diverted", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsDepDelayed"};

    /**
     * Per-dataset preparation hook. It mutates a freshly parsed frame in place and reports which
     * column is the response. Only subclassable from within this test class (private constructor).
     */
    public abstract static class PrepData {
        private PrepData() {
        }

        /**
         * Prepares {@code frame} in place.
         *
         * @return index of the response column in {@code frame}; negative if it was not found
         */
        abstract int prep(Frame frame);
    }

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testBasicEnsembleAUTOMetalearner() {
        basicEnsembleOnStandardDatasets(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.AUTO);
    }

    @Test
    public void testBasicEnsembleGBMMetalearner() {
        basicEnsembleOnStandardDatasets(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.gbm);
    }

    @Test
    public void testBasicEnsembleDRFMetalearner() {
        basicEnsembleOnStandardDatasets(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.drf);
    }

    @Test
    public void testBasicEnsembleDeepLearningMetalearner() {
        basicEnsembleOnStandardDatasets(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.deeplearning);
    }

    @Test
    public void testBasicEnsembleGLMMetalearner() {
        StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm glm =
                StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm.glm;
        basicEnsemble("./smalldata/junit/cars.csv", null, dropColumnThenResponse("name", "economy (mpg)"), false, DistributionFamily.gaussian, glm);
        basicEnsemble("./smalldata/junit/test_tree_minmax.csv", null, responseOnly("response"), false, DistributionFamily.bernoulli, glm);
        basicEnsemble("./smalldata/logreg/prostate.csv", null, dropColumnThenResponse("ID", "CAPSULE"), false, DistributionFamily.bernoulli, glm);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", responseOnly("CAPSULE"), false, DistributionFamily.bernoulli, glm);
        basicEnsemble("./smalldata/gbm_test/alphabet_cattest.csv", null, responseOnly("y"), false, DistributionFamily.bernoulli, glm);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlinesPrep(), false, DistributionFamily.bernoulli, glm);
        basicEnsemble("./smalldata/logreg/prostate.csv", null, dropColumnThenResponse("ID", "RACE"), false, DistributionFamily.multinomial, glm);
        basicEnsemble("./smalldata/junit/cars.csv", null, dropColumnThenResponse("name", "cylinders"), false, DistributionFamily.multinomial, glm);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, responseOnly("class"), false, DistributionFamily.multinomial, glm);
    }

    /** Runs the four datasets shared by the AUTO/GBM/DRF/DeepLearning metalearner tests. */
    private void basicEnsembleOnStandardDatasets(StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm metalearnerAlgorithm) {
        basicEnsemble("./smalldata/junit/cars.csv", null, dropColumnThenResponse("name", "economy (mpg)"), false, DistributionFamily.gaussian, metalearnerAlgorithm);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlinesPrep(), false, DistributionFamily.bernoulli, metalearnerAlgorithm);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, responseOnly("class"), false, DistributionFamily.multinomial, metalearnerAlgorithm);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", responseOnly("CAPSULE"), false, DistributionFamily.bernoulli, metalearnerAlgorithm);
    }

    /** Uses column {@code response} as the response and leaves the frame otherwise untouched. */
    private static PrepData responseOnly(final String response) {
        return new PrepData() {
            @Override
            int prep(Frame frame) {
                return frame.find(response);
            }
        };
    }

    /** Removes (and deletes) column {@code dropped}, then uses column {@code response} as the response. */
    private static PrepData dropColumnThenResponse(final String dropped, final String response) {
        return new PrepData() {
            @Override
            int prep(Frame frame) {
                frame.remove(dropped).remove();
                return frame.find(response);
            }
        };
    }

    /** Removes every column in {@link #ignored_aircols}, then uses "IsArrDelayed" as the response. */
    private static PrepData airlinesPrep() {
        return new PrepData() {
            @Override
            int prep(Frame frame) {
                for (String column : ignored_aircols) {
                    frame.remove(column).remove();
                }
                return frame.find("IsArrDelayed");
            }
        };
    }

    /** Distributions for which a numeric response is converted to categorical before training. */
    private static boolean isClassification(DistributionFamily family) {
        return family == DistributionFamily.bernoulli
                || family == DistributionFamily.multinomial
                || family == DistributionFamily.modified_huber;
    }

    /**
     * Trains GBM and DRF base models on one dataset, stacks them, and removes everything it created.
     *
     * @param trainingFile path of the training data
     * @param validationFile path of the validation data, or {@code null} for none
     * @param prepData preparation applied to the training (and validation) frame; must return a
     *     non-negative response index
     * @param dupeTrainingFrameToValidationFrame if true, a copy of the training frame is used as the
     *     validation frame (replacing any parsed validation frame)
     * @param family distribution for the base models; classification families force a categorical response
     * @param metalearnerAlgorithm metalearner used by the ensemble
     * @return the trained ensemble's output (the model itself is deleted before returning)
     */
    public StackedEnsembleModel.StackedEnsembleOutput basicEnsemble(String trainingFile, String validationFile, PrepData prepData, boolean dupeTrainingFrameToValidationFrame, DistributionFamily family, StackedEnsembleModel.StackedEnsembleParameters.MetalearnerAlgorithm metalearnerAlgorithm) {
        Set<Frame> framesBefore = new HashSet<>(Arrays.asList(Frame.fetchAll()));
        GBMModel gbmModel = null;
        DRFModel drfModel = null;
        StackedEnsembleModel ensemble = null;
        Frame trainingFrame = null;
        Frame validationFrame = null;
        Throwable failure = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file(trainingFile);
            if (validationFile != null) {
                validationFrame = parse_test_file(validationFile);
            }
            int responseIdx = prepData.prep(trainingFrame);
            if (validationFrame != null) {
                prepData.prep(validationFrame);
            }
            // Frame.find returns -1 for a missing column; fail instead of silently training on column 0.
            Assert.assertTrue("response column not found after preparing " + trainingFile, responseIdx >= 0);

            if (isClassification(family) && !trainingFrame.vecs()[responseIdx].isCategorical()) {
                Scope.track(trainingFrame.replace(responseIdx, trainingFrame.vecs()[responseIdx].toCategoricalVec()));
                if (validationFrame != null) {
                    Scope.track(validationFrame.replace(responseIdx, validationFrame.vecs()[responseIdx].toCategoricalVec()));
                }
            }
            // Publish the prepared frames so the model builders see the modified columns.
            DKV.put(trainingFrame);
            if (validationFrame != null) {
                DKV.put(validationFrame);
            }
            if (dupeTrainingFrameToValidationFrame) {
                if (validationFrame != null) {
                    validationFrame.remove(); // would otherwise leak once the reference is replaced
                }
                // NOTE(review): confirm whether Frame's copy constructor reuses the source key; if so
                // this put replaces the training frame's DKV entry.
                validationFrame = new Frame(trainingFrame);
                DKV.put(validationFrame);
            }
            Key<Frame> validationKey = validationFrame == null ? null : validationFrame._key;
            String response = trainingFrame._names[responseIdx];

            GBMModel.GBMParameters gbmParameters = new GBMModel.GBMParameters();
            gbmParameters._train = trainingFrame._key;
            gbmParameters._valid = validationKey;
            gbmParameters._response_column = response;
            gbmParameters._ntrees = 5;
            gbmParameters._distribution = family;
            gbmParameters._max_depth = 4;
            gbmParameters._min_rows = 1;
            gbmParameters._nbins = 50;
            gbmParameters._learn_rate = 0.2f; // float literal keeps the exact learn rate used so far
            gbmParameters._score_each_iteration = true;
            gbmParameters._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            gbmParameters._keep_cross_validation_predictions = true;
            gbmParameters._nfolds = 5;
            GBM gbm = new GBM(gbmParameters);
            gbmModel = (GBMModel) gbm.trainModel().get();
            Assert.assertTrue(gbm.isStopped());

            DRFModel.DRFParameters drfParameters = new DRFModel.DRFParameters();
            drfParameters._train = trainingFrame._key;
            drfParameters._valid = validationKey;
            drfParameters._response_column = response;
            drfParameters._distribution = family;
            drfParameters._ntrees = 5;
            drfParameters._max_depth = 4;
            drfParameters._min_rows = 1;
            drfParameters._nbins = 50;
            drfParameters._score_each_iteration = true;
            drfParameters._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            drfParameters._keep_cross_validation_predictions = true;
            drfParameters._nfolds = 5;
            DRF drf = new DRF(drfParameters);
            drfModel = (DRFModel) drf.trainModel().get();
            Assert.assertTrue(drf.isStopped());

            StackedEnsembleModel.StackedEnsembleParameters ensembleParameters = new StackedEnsembleModel.StackedEnsembleParameters();
            ensembleParameters._train = trainingFrame._key;
            ensembleParameters._valid = validationKey;
            ensembleParameters._response_column = response;
            ensembleParameters._metalearner_algorithm = metalearnerAlgorithm;
            ensembleParameters._base_models = new Key[]{gbmModel._key, drfModel._key};
            StackedEnsemble stackedEnsemble = new StackedEnsemble(ensembleParameters);
            ensemble = (StackedEnsembleModel) stackedEnsemble.trainModel().get();
            Assert.assertTrue(stackedEnsemble.isStopped());
            return ensemble._output;
        } catch (RuntimeException | Error e) {
            failure = e;
            throw e;
        } finally {
            try {
                if (trainingFrame != null) {
                    trainingFrame.remove();
                }
                if (validationFrame != null) {
                    validationFrame.remove();
                }
                removeBaseModel(gbmModel);
                removeBaseModel(drfModel);
                assertNoPreexistingFramesRemoved(framesBefore);
                removeEnsemble(ensemble);
            } catch (RuntimeException | Error cleanupFailure) {
                // Never let a cleanup problem hide the failure that aborted the test body.
                if (failure == null) {
                    throw cleanupFailure;
                }
                failure.addSuppressed(cleanupFailure);
            } finally {
                Scope.exit(); // must run even if cleanup fails, or the Scope stack is left unbalanced
            }
        }
    }

    /** Deletes a CV-trained base model plus its kept CV prediction frames and CV sub-models. */
    private static void removeBaseModel(Model<?, ?, ?> model) {
        if (model == null) {
            return;
        }
        model.delete();
        if (model._output._cross_validation_predictions != null) {
            for (Key<?> key : model._output._cross_validation_predictions) {
                if (key != null) {
                    key.remove();
                }
            }
        }
        if (model._output._cross_validation_holdout_predictions_frame_id != null) {
            model._output._cross_validation_holdout_predictions_frame_id.remove();
        }
        model.deleteCrossValidationModels();
    }

    /** Deletes the ensemble, its metalearner (with training metrics) and its level-one frame, if any. */
    private static void removeEnsemble(StackedEnsembleModel ensemble) {
        if (ensemble == null) {
            return;
        }
        ensemble.delete();
        ensemble.remove();
        if (ensemble._output._metalearner != null) {
            if (ensemble._output._metalearner._output._training_metrics != null) {
                ensemble._output._metalearner._output._training_metrics.remove();
            }
            ensemble._output._metalearner.remove();
            ensemble._output._metalearner.delete();
        }
        if (ensemble._output._levelone_frame_id != null) {
            ensemble._output._levelone_frame_id.remove();
        }
    }

    /**
     * Fails if any frame that existed before the run is gone afterwards.
     * NOTE(review): this is "before minus after", so it does not detect frames leaked by the run
     * (that would be "after minus before"); confirm which check is intended.
     */
    private static void assertNoPreexistingFramesRemoved(Set<Frame> framesBefore) {
        Set<Frame> missing = new HashSet<>(framesBefore);
        missing.removeAll(Arrays.asList(Frame.fetchAll()));
        Assert.assertEquals("finish with the same number of Frames as we started: " + missing, 0L, missing.size());
    }
}
