package hex.tree.gbm;

import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;

/**
 * Grid-search tests for GBM: coverage of the hyper space by models plus failures, collapsing of
 * duplicated hyper-parameter combinations, accumulation into an existing grid key, and a grid over
 * a randomly drawn hyper space.
 */
public class GBMGridTest extends TestUtil {
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs a multinomial grid on cars.csv with "cylinders" converted to a categorical response. The
     * learn-rate axis contains one illegal value (-1.0). Checks that models plus failures add up to
     * the hyper-space size. Models must cover the legal learn rates, and failures must correspond
     * to the illegal one.
     */
    @Test
    public void testCarsGrid() {
        Grid grid = null;
        Frame frame = null;
        Vec numericResponse = null;
        try {
            frame = parse_test_file("smalldata/junit/cars.csv");
            frame.remove("name").remove();
            // Replace the response with a categorical copy appended as the last column. The original
            // vec is no longer part of the frame, so it is removed explicitly in the finally block.
            numericResponse = frame.remove("cylinders");
            frame.add("cylinders", numericResponse.toCategoricalVec());
            DKV.put(frame);

            final Double[] legalLearnRates = {0.01, 0.1, 0.3};
            final Double[] illegalLearnRates = {-1.0};
            Map<String, Object[]> hyperParms = new HashMap<>();
            hyperParms.put("_ntrees", new Integer[]{1, 2});
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParms.put("_max_depth", new Integer[]{1, 2, 5});
            hyperParms.put("_learn_rate", ArrayUtils.join(legalLearnRates, illegalLearnRates));

            String[] hyperParamNames = hyperParms.keySet().toArray(new String[0]);
            Arrays.sort(hyperParamNames);
            int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = frame._key;
            params._response_column = "cylinders";
            grid = GridSearch.startGridSearch(null, params, hyperParms).get();

            Grid.SearchFailure failures = grid.getFailures();
            Assert.assertEquals("Size of grid (models+failures) should match to size of hyper space",
                    hyperSpaceSize, grid.getModelCount() + failures.getFailureCount());

            // Sort a copy. getHyperNames() may return the grid's own array, and the loop below pairs
            // it positionally with getHyperValues(), so the grid's order must stay untouched.
            String[] gridHyperNames = grid.getHyperNames().clone();
            Arrays.sort(gridHyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);

            // Collect the hyper-parameter values used by the successfully built models.
            Map<String, Set<Object>> modelParams = GridTestUtils.initMap(hyperParamNames);
            for (Key<Model> modelKey : grid.getModelKeys()) {
                GBMModel model = (GBMModel) modelKey.get();
                System.out.println(model._output._scored_train[model._output._ntrees]._mse + " "
                        + Arrays.deepToString(ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(model._parms))));
                GridTestUtils.extractParams(modelParams, model._parms, hyperParamNames);
            }
            // Models must cover exactly the hyper space restricted to the legal learn rates.
            hyperParms.put("_learn_rate", legalLearnRates);
            GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space",
                    hyperParms, modelParams);

            // Failed parameter sets must cover exactly the hyper space restricted to the illegal learn rate.
            Map<String, Set<Object>> failedParams = GridTestUtils.initMap(hyperParamNames);
            for (Model.Parameters failedParms : failures.getFailedParameters()) {
                GridTestUtils.extractParams(failedParams, failedParms, hyperParamNames);
            }
            hyperParms.put("_learn_rate", illegalLearnRates);
            GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space",
                    hyperParms, failedParams);
        } finally {
            if (numericResponse != null) {
                numericResponse.remove();
            }
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs a grid whose hyper-parameter arrays contain only duplicated values. Checks that every
     * returned model has the same key, i.e. the duplicate combination yields a single model.
     */
    @Test
    public void testDuplicatesCarsGrid() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/junit/cars_20mpg.csv");
            frame.remove("name").remove();
            // Move the response to the last column. The frame owns the vec again, so frame.remove()
            // cleans it up.
            frame.add("economy", frame.remove("economy"));
            DKV.put(frame);

            Map<String, Object[]> hyperParms = new HashMap<>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", new Integer[]{5, 5});
            hyperParms.put("_max_depth", new Integer[]{2, 2});
            hyperParms.put("_learn_rate", new Double[]{0.1, 0.1});

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = frame._key;
            params._response_column = "economy";
            grid = GridSearch.startGridSearch(null, params, hyperParms).get();

            Model[] models = grid.getModels();
            Assert.assertTrue("Number of returned models has to be > 0", models.length > 0);
            Key<?> firstKey = models[0]._key;
            for (Model model : models) {
                Assert.assertEquals("Number of constructed models has to be equal to 1", firstKey, model._key);
            }
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs the same single-combination search twice into one grid key. Checks that the grid still
     * holds exactly one model.
     */
    @Test
    public void testGridAccumulation() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/junit/cars_20mpg.csv");
            frame.remove("name").remove();
            // Move the response to the last column. The frame owns the vec again, so frame.remove()
            // cleans it up.
            frame.add("economy", frame.remove("economy"));
            DKV.put(frame);

            Map<String, Object[]> hyperParms = new HashMap<>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", new Integer[]{2});
            hyperParms.put("_max_depth", new Integer[]{2});
            hyperParms.put("_learn_rate", new Double[]{0.1});

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = frame._key;
            params._response_column = "economy";

            Key<Grid> gridKey = Key.make("accumulating_grid");
            // Both searches target the same grid key; the second must not add another model.
            GridSearch.startGridSearch(gridKey, params, hyperParms).get();
            grid = GridSearch.startGridSearch(gridKey, params, hyperParms).get();
            Assert.assertEquals("Number of returned models has to be 1", 1, grid.getModels().length);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Builds a gaussian grid over a randomly sized, randomly drawn hyper space and checks that it
     * holds one model per combination. It then trains a standalone GBM with one randomly chosen
     * value per axis. The seed is printed so a failing run can be reproduced.
     */
    @Test
    public void testRandomCarsGrid() {
        Grid grid = null;
        GBMModel rebuiltModel = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/junit/cars.csv");
            frame.remove("name").remove();
            // Move the response to the last column. The frame owns the vec again, so frame.remove()
            // cleans it up.
            frame.add("economy (mpg)", frame.remove("economy (mpg)"));
            DKV.put(frame);

            long seed = System.nanoTime();
            System.out.println("random seed: " + seed);
            Random random = new Random(seed);
            int ntreesCount = random.nextInt(4) + 1;
            int maxDepthCount = random.nextInt(4) + 1;
            int learnRateCount = random.nextInt(4) + 1;
            Integer[] ntreesValues = sampleDistinct(ArrayUtils.interval(1, 25), ntreesCount, random);
            Integer[] maxDepthValues = sampleDistinct(ArrayUtils.interval(1, 10), maxDepthCount, random);
            Double[] learnRateValues = sampleDistinct(ArrayUtils.interval(0.01, 1.0, 0.01), learnRateCount, random);

            Map<String, Object[]> hyperParms = new HashMap<>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", ntreesValues);
            hyperParms.put("_max_depth", maxDepthValues);
            hyperParms.put("_learn_rate", learnRateValues);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = frame._key;
            params._response_column = "economy (mpg)";
            grid = GridSearch.startGridSearch(null, params, hyperParms).get();

            System.out.println("ntrees search space: " + Arrays.toString(ntreesValues));
            System.out.println("max_depth search space: " + Arrays.toString(maxDepthValues));
            System.out.println("learn_rate search space: " + Arrays.toString(learnRateValues));
            int modelCount = grid.getModels().length;
            System.out.println("Grid consists of " + modelCount + " models");
            Assert.assertEquals("Grid should contain one model per hyper-parameter combination",
                    ntreesCount * maxDepthCount * learnRateCount, modelCount);

            // Train a standalone model with one randomly picked value from each searched axis.
            params._distribution = DistributionFamily.gaussian;
            params._ntrees = ntreesValues[random.nextInt(ntreesValues.length)];
            params._max_depth = maxDepthValues[random.nextInt(maxDepthValues.length)];
            params._learn_rate = learnRateValues[random.nextInt(learnRateValues.length)];
            GBM gbm = new GBM(params);
            rebuiltModel = (GBMModel) gbm.trainModel().get();
            Assert.assertTrue(gbm.isStopped());
            System.out.println("The rebuilt model's MSE: "
                    + rebuiltModel._output._scored_train[rebuiltModel._output._ntrees]._mse);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (rebuiltModel != null) {
                rebuiltModel.remove();
            }
        }
    }

    /**
     * Draws {@code count} elements from {@code pool} without replacement, in random order.
     *
     * <p>Shuffling uses {@code rng}, so the draw is reproducible from its seed. {@code pool} itself is
     * not modified.
     *
     * @param pool candidate values
     * @param count number of elements to draw; must satisfy {@code 0 <= count <= pool.length}
     * @param rng source of randomness
     * @return a new array of the same component type as {@code pool}, holding the drawn elements
     */
    private static <T> T[] sampleDistinct(T[] pool, int count, Random rng) {
        List<T> shuffled = new ArrayList<>(Arrays.asList(pool));
        Collections.shuffle(shuffled, rng);
        return shuffled.subList(0, count).toArray(Arrays.copyOf(pool, 0));
    }
}
