package hex.kmeans;

import hex.Model;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.kmeans.KMeans;
import hex.kmeans.KMeansModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/kmeans/KMeansGridTest.class */
public class KMeansGridTest extends TestUtil {
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testIrisGrid() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            HashMap hashMap = new HashMap();
            Integer[] numArr = {1, 2, 3, 4, 5};
            Integer[] numArr2 = {0};
            hashMap.put("_k", ArrayUtils.join(numArr, numArr2));
            hashMap.put("_init", new KMeans.Initialization[]{KMeans.Initialization.Random, KMeans.Initialization.PlusPlus, KMeans.Initialization.Furthest});
            hashMap.put("_seed", new Long[]{1L, 123456789L, 987654321L});
            String[] strArr = (String[]) hashMap.keySet().toArray(new String[hashMap.size()]);
            Arrays.sort(strArr);
            int crossProductSize = ArrayUtils.crossProductSize(hashMap);
            KMeansModel.KMeansParameters kMeansParameters = new KMeansModel.KMeansParameters();
            kMeansParameters._train = frame._key;
            grid = (Grid) GridSearch.startGridSearch((Key) null, kMeansParameters, hashMap).get();
            Grid.SearchFailure failures = grid.getFailures();
            Assert.assertEquals("Size of grid should match to size of hyper space", crossProductSize, grid.getModelCount() + failures.getFailureCount());
            String[] hyperNames = grid.getHyperNames();
            Arrays.sort(hyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", strArr, hyperNames);
            Map<String, Set<Object>> initMap = GridTestUtils.initMap(strArr);
            for (KMeansModel kMeansModel : grid.getModels()) {
                System.out.println(kMeansModel._output._tot_withinss + " " + Arrays.deepToString(ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(kMeansModel._parms))));
                GridTestUtils.extractParams(initMap, kMeansModel._parms, strArr);
            }
            hashMap.put("_k", numArr);
            GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", hashMap, initMap);
            Map<String, Set<Object>> initMap2 = GridTestUtils.initMap(strArr);
            for (KMeansModel.KMeansParameters kMeansParameters2 : failures.getFailedParameters()) {
                GridTestUtils.extractParams(initMap2, kMeansParameters2, strArr);
            }
            hashMap.put("_k", numArr2);
            GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", hashMap, initMap2);
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        } catch (Throwable th) {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            throw th;
        }
    }

    @Test
    public void testDuplicatesCarsGrid() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            frame.remove("class").remove();
            DKV.put(frame);
            HashMap hashMap = new HashMap();
            hashMap.put("_k", new Integer[]{3, 3, 3});
            hashMap.put("_init", new KMeans.Initialization[]{KMeans.Initialization.Random, KMeans.Initialization.Random, KMeans.Initialization.Random});
            hashMap.put("_seed", new Long[]{123456789L, 123456789L, 123456789L});
            KMeansModel.KMeansParameters kMeansParameters = new KMeansModel.KMeansParameters();
            kMeansParameters._train = frame._key;
            grid = (Grid) GridSearch.startGridSearch((Key) null, kMeansParameters, hashMap).get();
            Model[] models = grid.getModels();
            Assert.assertTrue("Number of returned models has to be > 0", models.length > 0);
            Key key = models[0]._key;
            for (Model model : models) {
                Assert.assertTrue("Number of constructed models has to be equal to 1", key == model._key);
            }
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        } catch (Throwable th) {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            throw th;
        }
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    @Test
    public void testUserPointsCarsGrid() {
        Grid grid = null;
        Frame frame = null;
        Frame frame2 = ArrayUtils.frame(ard(new double[]{ard(new double[]{5.0d, 3.4d, 1.5d, 0.2d}), ard(new double[]{7.0d, 3.2d, 4.7d, 1.4d}), ard(new double[]{6.5d, 3.0d, 5.8d, 2.2d})}));
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            frame.remove("class").remove();
            DKV.put(frame);
            HashMap hashMap = new HashMap();
            hashMap.put("_k", new Integer[]{3});
            hashMap.put("_init", new KMeans.Initialization[]{KMeans.Initialization.Random, KMeans.Initialization.PlusPlus, KMeans.Initialization.User, KMeans.Initialization.Furthest});
            hashMap.put("_seed", new Long[]{123456789L});
            KMeansModel.KMeansParameters kMeansParameters = new KMeansModel.KMeansParameters();
            kMeansParameters._train = frame._key;
            kMeansParameters._user_points = frame2._key;
            grid = (Grid) GridSearch.startGridSearch((Key) null, kMeansParameters, hashMap).get();
            Integer valueOf = Integer.valueOf(grid.getModels().length);
            System.out.println("Grid consists of " + valueOf + " models");
            Assert.assertTrue(valueOf.intValue() == 4);
            if (frame != null) {
                frame.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        } catch (Throwable th) {
            if (frame != null) {
                frame.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            throw th;
        }
    }

    /* JADX WARN: Type inference failed for: r0v4, types: [double[], double[][]] */
    @Ignore("PUBDEV-1675")
    public void testRandomCarsGrid() {
        Grid grid = null;
        KMeansModel kMeansModel = null;
        Frame frame = null;
        Frame frame2 = ArrayUtils.frame(ard(new double[]{ard(new double[]{5.0d, 3.4d, 1.5d, 0.2d}), ard(new double[]{7.0d, 3.2d, 4.7d, 1.4d}), ard(new double[]{6.5d, 3.0d, 5.8d, 2.2d})}));
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            frame.remove("class").remove();
            DKV.put(frame);
            HashMap hashMap = new HashMap();
            Random random = new Random();
            Integer valueOf = Integer.valueOf(random.nextInt(4) + 1);
            Integer valueOf2 = Integer.valueOf(random.nextInt(4) + 1);
            Integer valueOf3 = Integer.valueOf(random.nextInt(4) + 1);
            Integer valueOf4 = Integer.valueOf(random.nextInt(2) + 1);
            ArrayList arrayList = new ArrayList(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50));
            Collections.shuffle(arrayList);
            Integer[] numArr = new Integer[valueOf.intValue()];
            for (int i = 0; i < valueOf.intValue(); i++) {
                numArr[i] = (Integer) arrayList.get(i);
            }
            ArrayList arrayList2 = new ArrayList(Arrays.asList(KMeans.Initialization.Random, KMeans.Initialization.User, KMeans.Initialization.PlusPlus, KMeans.Initialization.Furthest));
            Collections.shuffle(arrayList2);
            KMeans.Initialization[] initializationArr = new KMeans.Initialization[valueOf2.intValue()];
            for (int i2 = 0; i2 < valueOf2.intValue(); i2++) {
                initializationArr[i2] = (KMeans.Initialization) arrayList2.get(i2);
            }
            ArrayList arrayList3 = new ArrayList(Arrays.asList(0L, 1L, 123456789L, 987654321L));
            Collections.shuffle(arrayList3);
            Long[] lArr = new Long[valueOf3.intValue()];
            for (int i3 = 0; i3 < valueOf3.intValue(); i3++) {
                lArr[i3] = (Long) arrayList3.get(i3);
            }
            ArrayList arrayList4 = new ArrayList(Arrays.asList(1, 0));
            Collections.shuffle(arrayList4);
            Integer[] numArr2 = new Integer[valueOf4.intValue()];
            for (int i4 = 0; i4 < valueOf4.intValue(); i4++) {
                numArr2[i4] = (Integer) arrayList4.get(i4);
            }
            hashMap.put("_k", numArr);
            hashMap.put("_init", initializationArr);
            hashMap.put("_seed", lArr);
            hashMap.put("_standardize", numArr2);
            System.out.println("k search space: " + Arrays.toString(numArr));
            System.out.println("max_depth search space: " + Arrays.toString(initializationArr));
            System.out.println("seed search space: " + Arrays.toString(lArr));
            System.out.println("sample_rate search space: " + Arrays.toString(numArr2));
            KMeansModel.KMeansParameters kMeansParameters = new KMeansModel.KMeansParameters();
            kMeansParameters._train = frame._key;
            if (Arrays.asList(initializationArr).contains(KMeans.Initialization.User)) {
                kMeansParameters._user_points = frame2._key;
            }
            grid = (Grid) GridSearch.startGridSearch((Key) null, kMeansParameters, hashMap).get();
            Integer valueOf5 = Integer.valueOf(grid.getModels().length);
            System.out.println("Grid consists of " + valueOf5 + " models");
            Assert.assertTrue(valueOf5.intValue() == ((valueOf.intValue() * valueOf2.intValue()) * valueOf4.intValue()) * valueOf3.intValue());
            HashMap hashMap2 = new HashMap();
            Integer num = numArr[random.nextInt(numArr.length)];
            hashMap2.put("_k", new Integer[]{num});
            KMeans.Initialization initialization = initializationArr[random.nextInt(initializationArr.length)];
            hashMap2.put("_init", initializationArr);
            Long l = lArr[random.nextInt(lArr.length)];
            hashMap2.put("_seed", lArr);
            Integer num2 = numArr2[random.nextInt(numArr2.length)];
            hashMap2.put("_standardize", numArr2);
            kMeansParameters._k = num.intValue();
            kMeansParameters._init = initialization;
            kMeansParameters._seed = l.longValue();
            kMeansParameters._standardize = num2.intValue() == 1;
            kMeansModel = (KMeansModel) new KMeans(kMeansParameters).trainModel().get();
            System.out.println("The rebuilt model's betweenss: " + kMeansModel._output._betweenss);
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (kMeansModel != null) {
                kMeansModel.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
        } catch (Throwable th) {
            if (frame != null) {
                frame.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (kMeansModel != null) {
                kMeansModel.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
            throw th;
        }
    }
}
