package hex.grid;

import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.tree.gbm.GBMModel;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Paths;
import java.util.HashMap;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.util.PojoUtils;

/*
 * JUnit tests for grid search over GBM models (GridSearch / Grid).
 * Recovered by decompilation from: input_file:hex/grid/GridTest.class
 */
public class GridTest extends TestUtil {

    /**
     * JUnit rule supplying temporary directories; the export tests in this class call
     * {@code newFolder()} on it to obtain a directory for exported grids and models.
     */
    @Rule
    public TemporaryFolder temporaryFolder = new TemporaryFolder();

    /**
     * Runs before every test and calls {@code stall_till_cloudsize(1)} inherited from
     * {@link TestUtil} (presumably waits until a cloud of at least one node is up).
     */
    @Before
    public void setUp() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Starts a random-discrete GBM grid search (parallelism 2) over 7 x 3 hyper-parameter
     * combinations with a runtime limit of one second, and asserts that the resulting grid
     * does not contain a model for every combination.
     */
    @Test
    public void testParallelModelTimeConstraint() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            Scope.track(trainingFrame);

            final Integer[] ntreesOptions = {5, 50, 7, 8, 9, 10, 500};
            final Integer[] maxDepthOptions = {2, 3, 4};
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", ntreesOptions);
            hyperParams.put("_max_depth", maxDepthOptions);

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "IsDepDelayed";
            params._seed = 42L;

            final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria =
                    new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
            final Job<Grid> searchJob = GridSearch.startGridSearch(
                    null, params, hyperParams, new GridSearch.SimpleParametersBuilderFactory(), searchCriteria, 2);
            // NOTE(review): the runtime limit is set after the search job has been started;
            // confirm this ordering is intentional before moving it.
            searchCriteria.set_max_runtime_secs(1.0d);
            Scope.track_generic(searchJob);

            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            Assert.assertNotEquals(ntreesOptions.length * maxDepthOptions.length, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Starts a Cartesian GBM grid search (parallelism 2) over 7 x 4 hyper-parameter
     * combinations, requests a stop immediately after starting it, and asserts that the
     * resulting grid does not contain a model for every combination.
     */
    @Test
    public void testParallelUserStopRequest() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            Scope.track(trainingFrame);

            final Integer[] ntreesOptions = {5, 50, 7, 8, 9, 10, 500};
            final Integer[] maxDepthOptions = {2, 3, 4, 50};
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", ntreesOptions);
            hyperParams.put("_max_depth", maxDepthOptions);

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "IsDepDelayed";
            params._seed = 42L;

            final Job<Grid> searchJob = GridSearch.startGridSearch(
                    null, params, hyperParams, new GridSearch.SimpleParametersBuilderFactory(),
                    new HyperSpaceSearchCriteria.CartesianSearchCriteria(), 2);
            Scope.track_generic(searchJob);
            searchJob.stop();

            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            // Track whatever models were built before the stop so they are cleaned up with the scope.
            for (Keyed model : grid.getModels()) {
                Scope.track_generic(model);
            }
            Assert.assertNotEquals(ntreesOptions.length * maxDepthOptions.length, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a GBM grid search with a fixed parallelism of 5 over 6 x 3 hyper-parameter
     * combinations and asserts that the grid contains one model per combination.
     */
    @Test
    public void testParallelGridSearch() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            Scope.track(trainingFrame);

            final Integer[] ntreesOptions = {5, 50, 7, 8, 9, 10};
            final Integer[] maxDepthOptions = {2, 3, 4};
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", ntreesOptions);
            hyperParams.put("_max_depth", maxDepthOptions);

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "IsDepDelayed";
            params._seed = 42L;

            final Job<Grid> searchJob = GridSearch.startGridSearch(null, params, hyperParams, 5);
            Scope.track_generic(searchJob);

            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            Assert.assertEquals(ntreesOptions.length * maxDepthOptions.length, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a GBM grid search whose parallelism is taken from
     * {@code GridSearch.getAdaptiveParallelism()} over 6 x 3 hyper-parameter combinations
     * and asserts that the grid contains one model per combination.
     */
    @Test
    public void testAdaptiveParallelGridSearch() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            Scope.track(trainingFrame);

            final Integer[] ntreesOptions = {5, 50, 7, 8, 9, 10};
            final Integer[] maxDepthOptions = {2, 3, 4};
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", ntreesOptions);
            hyperParams.put("_max_depth", maxDepthOptions);

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "IsDepDelayed";
            params._seed = 42L;

            final Job<Grid> searchJob =
                    GridSearch.startGridSearch(null, params, hyperParams, GridSearch.getAdaptiveParallelism());
            Scope.track_generic(searchJob);

            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            Assert.assertEquals(ntreesOptions.length * maxDepthOptions.length, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * First runs a grid search with a {@code _min_rows} value that cannot be satisfied and
     * asserts that no model is built and exactly one failure is recorded. Then re-runs the
     * search under the same grid key with {@code _min_rows = 10} and asserts that one GBM
     * model exists and the failure records are empty.
     *
     * <p>The method name contains a typo; it is kept unchanged so the test keeps its identity
     * in reports.
     */
    @Test
    public void testFaileH2OdParamsCleanup() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_min_rows", new Integer[]{5000000});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            // First run: the only combination is expected to fail.
            final Job<Grid> failingJob = GridSearch.startGridSearch(null, params, hyperParams);
            Scope.track_generic(failingJob);
            final Grid failedGrid = failingJob.get();
            Scope.track_generic(failedGrid);
            Assert.assertEquals(0, failedGrid.getModelCount());

            final Grid.SearchFailure firstRunFailures = failedGrid.getFailures();
            Assert.assertEquals(1, firstRunFailures.getFailureCount());
            Assert.assertEquals(1, firstRunFailures.getFailedParameters().length);
            Assert.assertEquals(1, firstRunFailures.getFailedRawParameters().length);
            Assert.assertEquals(1, firstRunFailures.getFailureDetails().length);
            Assert.assertEquals(1, firstRunFailures.getFailureStackTraces().length);
            Assert.assertTrue(firstRunFailures.getFailureStackTraces()[0].contains("Details: ERRR on field: _min_rows: The dataset size is too small to split for min_rows=5000000.0: must have at least 1.0E7 (weighted) rows"));

            // Second run: same grid key, but with a min_rows value that can be trained.
            hyperParams.put("_min_rows", new Integer[]{10});
            final Job<Grid> retryJob = GridSearch.startGridSearch(failedGrid._key, params, hyperParams);
            Scope.track_generic(retryJob);
            final Grid retriedGrid = retryJob.get();
            Scope.track_generic(retriedGrid);
            Assert.assertEquals(1, retriedGrid.getModelCount());

            final Model rebuiltModel = retriedGrid.getModels()[0];
            Assert.assertTrue(rebuiltModel instanceof GBMModel);
            // The cast is required: _min_rows is declared on the GBM parameters, not on Model.Parameters.
            Assert.assertEquals(10.0d, ((GBMModel) rebuiltModel)._parms._min_rows, 0.0d);

            final Grid.SearchFailure secondRunFailures = retriedGrid.getFailures();
            Assert.assertEquals(0, secondRunFailures.getFailureCount());
            Assert.assertEquals(0, secondRunFailures.getFailedParameters().length);
            Assert.assertEquals(0, secondRunFailures.getFailedRawParameters().length);
            Assert.assertEquals(0, secondRunFailures.getFailureDetails().length);
            Assert.assertEquals(0, secondRunFailures.getFailureStackTraces().length);
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a single-combination GBM grid search with {@code _export_checkpoints_dir} pointing
     * at a fresh temporary folder, then asserts that a regular file named after the grid key
     * exists in that folder and that the grid loaded back from it has the same model keys.
     *
     * @throws IOException if the temporary folder cannot be created or the exported file cannot be read
     */
    @Test
    public void gridSearchExportCheckpointsDir() throws IOException {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";
            params._export_checkpoints_dir = temporaryFolder.newFolder().getAbsolutePath();

            final Job<Grid> searchJob = GridSearch.startGridSearch(null, params, hyperParams);
            Scope.track_generic(searchJob);
            final Grid originalGrid = searchJob.get();
            Scope.track_generic(originalGrid);

            final File exportedGridFile = new File(params._export_checkpoints_dir, originalGrid._key.toString());
            Assert.assertTrue(exportedGridFile.exists());
            Assert.assertTrue(exportedGridFile.isFile());

            final Grid reloadedGrid = loadGridFromFile(exportedGridFile);
            Assert.assertArrayEquals(originalGrid.getModelKeys(), reloadedGrid.getModelKeys());
            Scope.track_generic(reloadedGrid);
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a single-combination GBM grid search (parallelism 1), exports the grid with
     * {@code exportBinary} and its models with {@code exportModelsBinary} into a fresh
     * temporary folder, and asserts that a file exists for the grid and for every model key.
     *
     * @throws IOException if the temporary folder cannot be created or an export fails
     */
    @Test
    public void gridSearchManualExport() throws IOException {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            final String exportDir = temporaryFolder.newFolder().getAbsolutePath();

            final Job<Grid> searchJob = GridSearch.startGridSearch(null, params, hyperParams, 1);
            Scope.track_generic(searchJob);
            final Grid grid = searchJob.get();
            Scope.track_generic(grid);

            // Export the grid itself to <exportDir>/<grid key>.
            final String gridExportPath = exportDir + "/" + grid._key.toString();
            grid.exportBinary(gridExportPath);
            Assert.assertTrue(Files.exists(Paths.get(gridExportPath)));

            // Export the models; each one is expected under <exportDir>/<model key>.
            grid.exportModelsBinary(exportDir);
            for (Model model : grid.getModels()) {
                Assert.assertTrue(Files.exists(Paths.get(exportDir, model._key.toString())));
            }
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Parallel variant of the checkpoint-export test: runs a 3 x 2 x 1 GBM grid search with
     * parallelism 2 and {@code _export_checkpoints_dir} pointing at a fresh temporary folder,
     * then asserts that a regular file named after the grid key exists there and that the
     * grid loaded back from it has the same model keys.
     *
     * @throws IOException if the temporary folder cannot be created or the exported file cannot be read
     */
    @Test
    public void gridSearchExportCheckpointsDirParallel() throws IOException {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5, 10, 50});
            hyperParams.put("_max_depth", new Integer[]{2, 3});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";
            params._export_checkpoints_dir = temporaryFolder.newFolder().getAbsolutePath();

            final Job<Grid> searchJob = GridSearch.startGridSearch(null, params, hyperParams, 2);
            Scope.track_generic(searchJob);
            final Grid originalGrid = searchJob.get();
            Scope.track_generic(originalGrid);

            final File exportedGridFile = new File(params._export_checkpoints_dir, originalGrid._key.toString());
            Assert.assertTrue(exportedGridFile.exists());
            Assert.assertTrue(exportedGridFile.isFile());

            final Grid reloadedGrid = loadGridFromFile(exportedGridFile);
            Assert.assertArrayEquals(originalGrid.getModelKeys(), reloadedGrid.getModelKeys());
            Scope.track_generic(reloadedGrid);
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Deserializes a grid from the given file through an {@link AutoBuffer}.
     *
     * <p>Fails the calling test (via a JUnit assertion) when the deserialized object is not a
     * {@link Grid}. The input stream is always closed, whether reading succeeds or not.
     *
     * @param file file containing a serialized grid
     * @return the grid read from {@code file}
     * @throws IOException if the file cannot be opened or closing the stream fails
     */
    private static Grid loadGridFromFile(File file) throws IOException {
        try (FileInputStream input = new FileInputStream(file)) {
            // Held as Object so that the instanceof check below is a real type check.
            final Object deserialized = new AutoBuffer(input).get();
            Assert.assertTrue(deserialized instanceof Grid);
            return (Grid) deserialized;
        }
    }

    /**
     * Builds a one-model grid, appends a synthetic failure ("Test exception") to it, then
     * re-runs the search under the same grid key with a different learn rate and asserts
     * that the grid now holds two GBM models and still reports exactly one failure.
     */
    @Test
    public void testFailedParamsRetention() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_min_rows", new Integer[]{10});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            // First run: one model, no failures.
            final Job<Grid> firstJob = GridSearch.startGridSearch(null, params, hyperParams);
            Scope.track_generic(firstJob);
            final Grid firstGrid = firstJob.get();
            Scope.track_generic(firstGrid);
            Assert.assertEquals(1, firstGrid.getModelCount());
            Assert.assertEquals(0, firstGrid.getFailures().getFailureCount());

            // Inject an artificial failure for the model that was built.
            firstGrid.appendFailedModelParameters(
                    firstGrid.getModels()[0]._key, params, new RuntimeException("Test exception"));
            final Grid.SearchFailure injectedFailures = firstGrid.getFailures();
            Assert.assertEquals(1, injectedFailures.getFailureCount());

            // Second run: same grid key, new learn rate, so one additional model is expected.
            hyperParams.put("_learn_rate", new Double[]{0.5d});
            final Job<Grid> secondJob = GridSearch.startGridSearch(firstGrid._key, params, hyperParams);
            Scope.track_generic(secondJob);
            final Grid secondGrid = secondJob.get();
            Scope.track_generic(secondGrid);
            Assert.assertEquals(2, secondGrid.getModelCount());
            Assert.assertTrue(secondGrid.getModels()[0] instanceof GBMModel);
            Assert.assertTrue(secondGrid.getModels()[1] instanceof GBMModel);
            Assert.assertEquals(1, secondGrid.getFailures().getFailureCount());
            // NOTE(review): this inspects the failures captured before the second run, not
            // secondGrid.getFailures(); confirm whether the retained failure should be checked instead.
            Assert.assertTrue(injectedFailures.getFailureStackTraces()[0].contains("Test exception"));
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Registers the keys "2", "1" and "3" on an otherwise empty grid (under the long values
     * 3, 2 and 1 respectively) and asserts that {@code getModelKeys()} returns them in the
     * order "1", "2", "3".
     */
    @Test
    public void testGetModelKeys() {
        final Grid emptyGrid = new Grid((Key) null, (Model.Parameters) null, (String[]) null, (PojoUtils.FieldNaming) null);

        emptyGrid.putModel(3L, Key.make("2"));
        emptyGrid.putModel(2L, Key.make("1"));
        emptyGrid.putModel(1L, Key.make("3"));

        final Key[] expectedKeys = {Key.make("1"), Key.make("2"), Key.make("3")};
        final Key[] actualKeys = emptyGrid.getModelKeys();
        Assert.assertArrayEquals(expectedKeys, actualKeys);
    }

    /**
     * Runs a GBM grid search with parallelism 2 over five {@code _min_rows} values and asserts
     * five models; then re-runs it under the same grid key with a different learn rate and
     * asserts that the grid holds ten models in total.
     */
    @Test
    public void testParallelCartesian() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_min_rows", new Integer[]{10, 11, 12, 13, 14});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            // First pass: one model per min_rows value.
            final Job<Grid> firstJob = GridSearch.startGridSearch(null, params, hyperParams, 2);
            Scope.track_generic(firstJob);
            final Grid firstGrid = firstJob.get();
            Scope.track_generic(firstGrid);
            Assert.assertEquals(5, firstGrid.getModelCount());

            // Second pass into the same grid key with a new learn rate: five more models.
            hyperParams.put("_learn_rate", new Double[]{0.5d});
            final Job<Grid> secondJob = GridSearch.startGridSearch(firstGrid._key, params, hyperParams, 2);
            Scope.track_generic(secondJob);
            final Grid secondGrid = secondJob.get();
            Scope.track_generic(secondGrid);
            Assert.assertEquals(10, secondGrid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a random-discrete GBM grid search with {@code max_models = 2} and parallelism 4
     * over five candidate combinations and asserts that exactly two models are built.
     */
    @Test
    public void test_parallel_random_search_with_max_models_being_less_than_parallelism() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_min_rows", new Integer[]{10, 11, 12, 13, 14});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            final GridSearch.SimpleParametersBuilderFactory builderFactory =
                    new GridSearch.SimpleParametersBuilderFactory();
            final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria =
                    new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
            searchCriteria.set_max_models(2);

            final Job<Grid> searchJob =
                    GridSearch.startGridSearch(null, params, hyperParams, builderFactory, searchCriteria, 4);
            Scope.track_generic(searchJob);
            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            Assert.assertEquals(2, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }

    /**
     * Runs a random-discrete GBM grid search with {@code max_models = 3} and parallelism 2
     * over five candidate combinations and asserts that exactly three models are built.
     */
    @Test
    public void test_parallel_random_search_with_max_models_being_greater_than_parallelism() {
        try {
            Scope.enter();
            final Frame trainingFrame = parse_test_file("smalldata/iris/iris_train.csv");
            Scope.track(trainingFrame);

            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParams.put("_ntrees", new Integer[]{5});
            hyperParams.put("_max_depth", new Integer[]{2});
            hyperParams.put("_min_rows", new Integer[]{10, 11, 12, 13, 14});
            hyperParams.put("_learn_rate", new Double[]{0.7d});

            final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "species";

            final GridSearch.SimpleParametersBuilderFactory builderFactory =
                    new GridSearch.SimpleParametersBuilderFactory();
            final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria =
                    new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
            searchCriteria.set_max_models(3);

            final Job<Grid> searchJob =
                    GridSearch.startGridSearch(null, params, hyperParams, builderFactory, searchCriteria, 2);
            Scope.track_generic(searchJob);
            final Grid grid = searchJob.get();
            Scope.track_generic(grid);
            Assert.assertEquals(3, grid.getModelCount());
        } finally {
            // Single exit point: the scope is left exactly once, on success and on failure.
            Scope.exit();
        }
    }
}
