package hex.tree.gbm;

import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.ModelSerializationTest;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.VecUtils;

/* loaded from: input_file:hex/tree/gbm/GBMCheckpointTest.class */
/**
 * Checkpoint tests for GBM.
 *
 * <p>Each test trains a prior model, continues it from a checkpoint up to a larger tree count, and
 * trains a third model from scratch with that same total tree count. The checkpointed model and the
 * from-scratch model are asserted to contain the same trees (via {@code
 * ModelSerializationTest.assertTreeEquals}). Their tree keys are also asserted to differ.
 */
public class GBMCheckpointTest extends TestUtil {
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 2, 2);
    }

    @Test(expected = H2OIllegalArgumentException.class)
    @Ignore
    public void testCheckpointWrongParams() {
        // NOTE(review): the two sample-rate arguments are accepted by the helper below but are never
        // copied into the model parameters, so differing sample rates are not actually exercised here.
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f);
    }

    @Test
    public void testCheckpointReconstruction4Regression() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Regression2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 5, 3);
    }

    /** Runs the checkpoint reconstruction check with the default (unused) sample rates. */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification,
                                              int ntreesInPriorModel, int ntreesInNewModel) {
        testCheckPointReconstruction(dataset, responseIdx, classification,
                ntreesInPriorModel, ntreesInNewModel, 0.632f, 0.632f);
    }

    /**
     * Trains prior, checkpointed and from-scratch GBM models on {@code dataset} and asserts that the
     * checkpointed and from-scratch models have equal trees stored under distinct keys.
     *
     * @param dataset test file path passed to {@code parse_test_file}
     * @param responseIdx column index of the response
     * @param classification if true, the response column is converted to categorical first
     * @param ntreesInPriorModel number of trees in the model used as checkpoint
     * @param ntreesInNewModel number of additional trees built on top of the checkpoint
     * @param sampleRateInPriorModel currently unused — NOTE(review): never applied to any parameters
     * @param sampleRateInNewModel currently unused — NOTE(review): never applied to any parameters
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification,
                                              int ntreesInPriorModel, int ntreesInNewModel,
                                              float sampleRateInPriorModel, float sampleRateInNewModel) {
        Frame frame = parse_test_file(dataset);
        GBMModel priorModel = null;
        GBMModel checkpointedModel = null;
        GBMModel scratchModel = null;
        // Preprocessing lives inside the try so the parsed frame is deleted even if it fails.
        try {
            // Drop the "economy" column when the dataset has one (Frame.remove returns null otherwise).
            Vec economy = frame.remove("economy");
            if (economy != null) {
                economy.remove();
            }
            DKV.put(frame);
            if (classification) {
                // Swap in a categorical copy of the response and free the original numeric vec.
                frame.replace(responseIdx, VecUtils.toCategoricalVec(frame.vec(responseIdx))).remove();
                DKV.put(frame._key, frame);
            }

            int totalTrees = ntreesInPriorModel + ntreesInNewModel;

            GBMModel.GBMParameters priorParams = newParams(frame, responseIdx, ntreesInPriorModel);
            priorModel = (GBMModel) new GBM(priorParams, Key.make("Initial model")).trainModel().get();

            GBMModel.GBMParameters checkpointParams = newParams(frame, responseIdx, totalTrees);
            checkpointParams._checkpoint = priorModel._key;
            checkpointedModel = (GBMModel) new GBM(checkpointParams, Key.make("Model from checkpoint")).trainModel().get();

            GBMModel.GBMParameters scratchParams = newParams(frame, responseIdx, totalTrees);
            scratchModel = (GBMModel) new GBM(scratchParams, Key.make("Validation model")).trainModel().get();

            CompressedTree[][] checkpointTrees = ModelSerializationTest.getTrees(checkpointedModel);
            CompressedTree[][] scratchTrees = ModelSerializationTest.getTrees(scratchModel);
            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", checkpointTrees, scratchTrees, true);

            // Equal content is expected, but each model must own its trees under its own keys.
            for (int k = 0; k < checkpointTrees.length; k++) {
                for (int c = 0; c < checkpointTrees[k].length; c++) {
                    if (checkpointTrees[k][c] != null) {
                        Assert.assertNotEquals(checkpointTrees[k][c]._key, scratchTrees[k][c]._key);
                    }
                }
            }
        } finally {
            frame.delete();
            if (priorModel != null) {
                priorModel.delete();
            }
            if (checkpointedModel != null) {
                checkpointedModel.delete();
            }
            if (scratchModel != null) {
                scratchModel.delete();
            }
        }
    }

    /** Builds the GBM parameters shared by all three models in the reconstruction test. */
    private static GBMModel.GBMParameters newParams(Frame train, int responseIdx, int ntrees) {
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._response_column = train.name(responseIdx);
        params._ntrees = ntrees;
        params._seed = 42L;
        params._max_depth = 5;
        params._learn_rate_annealing = 0.9d;
        params._score_each_iteration = true;
        return params;
    }

    // @Test is required for JUnit 4 to discover the method at all; without it, @Ignore has no effect
    // and the skip (with its reason) is never reported.
    @Test
    @Ignore("PUBDEV-1829")
    public void testCheckpointReconstruction4BinomialPUBDEV1829() {
        Frame train = null;
        Frame valid = null;
        GBMModel priorModel = null;
        GBMModel checkpointedModel = null;
        GBMModel scratchModel = null;
        try {
            train = parse_test_file("smalldata/jira/gbm_checkpoint_train.csv");
            valid = parse_test_file("smalldata/jira/gbm_checkpoint_valid.csv");
            train.remove("name").remove();
            train.remove("economy").remove();
            valid.remove("name").remove();
            valid.remove("economy").remove();
            // Remove and re-add the response in both frames (presumably so it becomes the last column).
            // The re-added vec belongs to its frame again, so Frame.delete() below frees it.
            train.add("economy_20mpg", train.remove("economy_20mpg"));
            DKV.put(train);
            valid.add("economy_20mpg", valid.remove("economy_20mpg"));
            DKV.put(valid);

            GBMModel.GBMParameters priorParams = newPubdev1829Params(train, valid, 5);
            priorModel = (GBMModel) new GBM(priorParams, Key.make("Initial model")).trainModel().get();

            GBMModel.GBMParameters checkpointParams = newPubdev1829Params(train, valid, 10);
            checkpointParams._checkpoint = priorModel._key;
            checkpointedModel = (GBMModel) new GBM(checkpointParams, Key.make("Model from checkpoint")).trainModel().get();

            GBMModel.GBMParameters scratchParams = newPubdev1829Params(train, valid, 10);
            scratchModel = (GBMModel) new GBM(scratchParams, Key.make("Validation model")).trainModel().get();

            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", ModelSerializationTest.getTrees(checkpointedModel), ModelSerializationTest.getTrees(scratchModel), true);
        } finally {
            if (train != null) {
                train.delete();
            }
            if (valid != null) {
                valid.delete();
            }
            if (priorModel != null) {
                priorModel.delete();
            }
            if (checkpointedModel != null) {
                checkpointedModel.delete();
            }
            if (scratchModel != null) {
                scratchModel.delete();
            }
        }
    }

    /** Builds the GBM parameters shared by the three models of the PUBDEV-1829 scenario. */
    private static GBMModel.GBMParameters newPubdev1829Params(Frame train, Frame valid, int ntrees) {
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._valid = valid._key;
        params._response_column = "economy_20mpg";
        params._ntrees = ntrees;
        params._max_depth = 5;
        params._min_rows = 10.0d;
        params._score_each_iteration = true;
        params._seed = 42L;
        return params;
    }
}
