package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.word2vec.WordEmbeddingModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.error.CountingErrorConsumer;
import hex.genmodel.easy.error.VoidErrorConsumer;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import hex.genmodel.easy.prediction.Word2VecPrediction;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapperTest.class */
public class EasyPredictModelWrapperTest {

    /* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapperTest$MyAutoEncoderModel.class */
    private static class MyAutoEncoderModel extends GenModel {
        private static final String[][] DOMAINS = {new String[]{"setosa", "versicolor", "virginica"}, 0, 0, 0, 0};

        private MyAutoEncoderModel() {
            super(new String[]{"Species", "Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"}, DOMAINS, (String) null);
        }

        public ModelCategory getModelCategory() {
            return ModelCategory.AutoEncoder;
        }

        public boolean isSupervised() {
            return false;
        }

        public int nfeatures() {
            return 5;
        }

        public int nclasses() {
            return 8;
        }

        public String getUUID() {
            return null;
        }

        public int getPredsSize() {
            return nclasses();
        }

        public double[] score0(double[] dArr, double[] dArr2) {
            double[] dArr3 = {0.0d, 1.3124d, 0.4864d, 0.0d, 6.1729d, 3.0573d, 17.8372d, 1.1993d};
            Assert.assertArrayEquals(new double[]{1.0d, 7.0d, 3.2d, 4.7d, 1.4d}, dArr, 1.0E-4d);
            Assert.assertEquals(dArr3.length, dArr2.length);
            System.arraycopy(dArr3, 0, dArr2, 0, dArr3.length);
            return dArr3;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapperTest$MyModel.class */
    public static class MyModel extends GenModel {
        MyModel(String[] strArr, String[][] strArr2) {
            super(strArr, strArr2, (String) null);
        }

        public int nclasses() {
            return 2;
        }

        public boolean isSupervised() {
            return true;
        }

        public double[] score0(double[] dArr, double[] dArr2) {
            Assert.assertEquals(dArr2.length, 3L);
            dArr2[0] = 0.0d;
            dArr2[1] = 1.0d;
            dArr2[2] = 0.0d;
            return dArr2;
        }

        public ModelCategory getModelCategory() {
            return ModelCategory.Binomial;
        }

        public String getUUID() {
            return null;
        }
    }

    /* loaded from: input_file:hex/genmodel/easy/EasyPredictModelWrapperTest$MyWordEmbeddingModel.class */
    private static class MyWordEmbeddingModel extends MojoModel implements WordEmbeddingModel {
        /* JADX WARN: Type inference failed for: r2v1, types: [java.lang.String[], java.lang.String[][]] */
        public MyWordEmbeddingModel() {
            super(new String[0], (String[][]) new String[0], (String) null);
        }

        public int getVecSize() {
            return 2;
        }

        public float[] transform0(String str, float[] fArr) {
            if (str.equals("NA")) {
                return null;
            }
            String[] split = str.split(",");
            for (int i = 0; i < split.length; i++) {
                fArr[i] = Float.valueOf(split[i]).floatValue();
            }
            return fArr;
        }

        public double[] score0(double[] dArr, double[] dArr2) {
            throw new IllegalStateException("Should never be called");
        }

        public ModelCategory getModelCategory() {
            return ModelCategory.WordEmbedding;
        }
    }

    /**
     * Builds the binomial stub model used by the tests: categorical columns {@code C1}
     * (two levels) and {@code C2} (three levels), plus {@code RESPONSE} with levels
     * {@code NO}/{@code YES}.
     */
    private static MyModel makeModel() {
        String[] names = {"C1", "C2", "RESPONSE"};
        String[][] domains = {
            {"c1level1", "c1level2"},
            {"c2level1", "c2level2", "c2level3"},
            {"NO", "YES"}
        };
        return new MyModel(names, domains);
    }

    /**
     * Exercises unknown categorical levels in two configurations.
     *
     * <p>With the default configuration an unknown level is expected to raise
     * {@link PredictUnknownCategoricalLevelException} and leave the per-column counters at
     * zero. With conversion to NA enabled, predictions are expected to succeed and the
     * {@link CountingErrorConsumer} to count each unknown level per column.
     */
    @Test
    public void testUnknownCategoricalLevels() throws Exception {
        MyModel rawModel = makeModel();

        // --- Default configuration: unknown levels are reported by exception. ---
        CountingErrorConsumer strictConsumer = new CountingErrorConsumer(rawModel);
        EasyPredictModelWrapper strictWrapper = new EasyPredictModelWrapper(
                new EasyPredictModelWrapper.Config()
                        .setModel(rawModel)
                        .setErrorConsumer(strictConsumer));

        RowData knownRow = new RowData();
        knownRow.put("C1", "c1level1");
        try {
            strictWrapper.predictBinomial(knownRow);
        } catch (PredictUnknownCategoricalLevelException e) {
            Assert.fail("Caught exception but should not have");
        }
        Assert.assertEquals(0L, sumUnknownCategoricals(strictConsumer));

        RowData unknownRow = new RowData();
        unknownRow.put("C1", "c1level1");
        unknownRow.put("C2", "unknownLevel");
        boolean caught = false;
        try {
            strictWrapper.predictBinomial(unknownRow);
        } catch (PredictUnknownCategoricalLevelException e) {
            caught = true;
        }
        Assert.assertTrue(caught);
        // The exception path is expected to leave the counters untouched.
        Assert.assertEquals(0L, sumUnknownCategoricals(strictConsumer));

        // --- Lenient configuration: unknown levels become NA and are counted. ---
        CountingErrorConsumer lenientConsumer = new CountingErrorConsumer(rawModel);
        EasyPredictModelWrapper lenientWrapper = new EasyPredictModelWrapper(
                new EasyPredictModelWrapper.Config()
                        .setModel(rawModel)
                        .setErrorConsumer(lenientConsumer)
                        .setConvertUnknownCategoricalLevelsToNa(true)
                        .setConvertInvalidNumbersToNa(true));

        // An empty row contains no unknown levels.
        lenientWrapper.predict(new RowData());
        Assert.assertEquals(0L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        // One unknown level in C2.
        RowData unknownC2Row = new RowData();
        unknownC2Row.put("C1", "c1level1");
        unknownC2Row.put("C2", "unknownLevel");
        lenientWrapper.predictBinomial(unknownC2Row);
        Assert.assertEquals(1L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        // A row with only known levels does not change the count.
        RowData allKnownRow = new RowData();
        allKnownRow.put("C1", "c1level1");
        allKnownRow.put("C2", "c2level3");
        lenientWrapper.predictBinomial(allKnownRow);
        Assert.assertEquals(1L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        // A column the model does not know about is not counted as an unknown level.
        RowData unknownColumnRow = new RowData();
        unknownColumnRow.put("C1", "c1level1");
        unknownColumnRow.put("unknownColumn", "unknownLevel");
        lenientWrapper.predictBinomial(unknownColumnRow);
        Assert.assertEquals(1L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        // Scoring the unknown-C2 row twice more adds two.
        lenientWrapper.predictBinomial(unknownC2Row);
        lenientWrapper.predictBinomial(unknownC2Row);
        Assert.assertEquals(3L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        // One unknown level in C1.
        RowData unknownC1Row = new RowData();
        unknownC1Row.put("C1", "unknownLevel");
        lenientWrapper.predictBinomial(unknownC1Row);
        Assert.assertEquals(4L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());

        Assert.assertEquals(1L, lenientConsumer.getUnknownCategoricalsPerColumn().get("C1").get());
        Assert.assertEquals(3L, lenientConsumer.getUnknownCategoricalsPerColumn().get("C2").get());
        Assert.assertEquals(4L, lenientConsumer.getTotalUnknownCategoricalLevelsSeen());
    }

    /** Sums the per-column unknown-level counters held by {@code consumer}. */
    private static long sumUnknownCategoricals(CountingErrorConsumer consumer) {
        long total = 0;
        for (AtomicLong perColumn : consumer.getUnknownCategoricalsPerColumn().values()) {
            total += perColumn.get();
        }
        return total;
    }

    /**
     * Checks that class probabilities of a binomial prediction come back sorted in
     * descending order: {@code NO} with probability 1 first, then {@code YES} with 0.
     */
    @Test
    public void testSortedClassProbability() throws Exception {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(makeModel());

        RowData row = new RowData();
        row.put("C1", "c1level1");

        SortedClassProbability[] sorted =
                wrapper.sortByDescendingClassProbability(wrapper.predictBinomial(row));

        // Expected values go first so failure messages read correctly.
        Assert.assertEquals("NO", sorted[0].name);
        Assert.assertEquals(1.0, sorted[0].probability, 0.001);
        Assert.assertEquals("YES", sorted[1].name);
        Assert.assertEquals(0.0, sorted[1].probability, 0.001);
    }

    /**
     * Feeds a row with one non-string value, two parseable "words" and one {@code "NA"} word
     * through the word-embedding stub and checks the resulting embedding map.
     */
    @Test
    public void testWordEmbeddingModel() throws Exception {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(new MyWordEmbeddingModel());

        RowData words = new RowData();
        words.put("C0", -1);
        words.put("C1", "0.9,0.1");
        words.put("C2", "0.1,0.9");
        words.put("C3", "NA");

        Word2VecPrediction prediction = wrapper.predictWord2Vec(words);

        // The non-string value under C0 must not show up in the result.
        Assert.assertFalse(prediction.wordEmbeddings.containsKey("C0"));

        // Parseable words map to their parsed vectors.
        float[] expectedC1 = {0.9f, 0.1f};
        float[] expectedC2 = {0.1f, 0.9f};
        Assert.assertArrayEquals(expectedC1, (float[]) prediction.wordEmbeddings.get("C1"), 1.0E-4f);
        Assert.assertArrayEquals(expectedC2, (float[]) prediction.wordEmbeddings.get("C2"), 1.0E-4f);

        // "NA" is present as a key but carries a null embedding.
        Assert.assertTrue(prediction.wordEmbeddings.containsKey("C3"));
        Assert.assertNull(prediction.wordEmbeddings.get("C3"));
    }

    /**
     * Scores one iris-style row through the auto-encoder stub and checks the expanded
     * original row, the raw reconstruction and the reconstruction mapped back to column names.
     */
    @Test
    public void testAutoEncoderModel() throws Exception {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(new MyAutoEncoderModel());

        RowData row = new RowData();
        row.put("Species", "versicolor");
        row.put("Sepal.Length", 7.0);
        row.put("Sepal.Width", 3.2);
        row.put("Petal.Length", 4.7);
        row.put("Petal.Width", 1.4);

        // Hold the result as Object so the instanceof check is meaningful and the narrowing
        // to AutoEncoderModelPrediction is an explicit cast.
        Object prediction = wrapper.predict(row);
        Assert.assertTrue(prediction instanceof AutoEncoderModelPrediction);
        AutoEncoderModelPrediction autoEncoderPrediction = (AutoEncoderModelPrediction) prediction;

        Assert.assertArrayEquals(
                new double[]{0.0, 1.0, 0.0, 0.0, 7.0, 3.2, 4.7, 1.4},
                autoEncoderPrediction.original,
                0.01);
        Assert.assertArrayEquals(
                new double[]{0.0, 1.3124, 0.4864, 0.0, 6.1729, 3.0573, 17.8372, 1.1993},
                autoEncoderPrediction.reconstructed,
                0.001);

        // Expected per-level reconstruction of the categorical column; the null key is the
        // extra slot that belongs to none of the named levels.
        HashMap<String, Object> expectedSpecies = new HashMap<String, Object>();
        expectedSpecies.put(null, 0.0);
        expectedSpecies.put("setosa", 0.0);
        expectedSpecies.put("virginica", 0.4864);
        expectedSpecies.put("versicolor", 1.3124);

        HashMap<String, Object> expectedRow = new HashMap<String, Object>();
        expectedRow.put("Petal.Length", 17.8372);
        expectedRow.put("Petal.Width", 1.1993);
        expectedRow.put("Sepal.Width", 3.0573);
        expectedRow.put("Sepal.Length", 6.1729);
        expectedRow.put("Species", expectedSpecies);

        Assert.assertEquals(expectedRow, autoEncoderPrediction.reconstructedRowData);
    }

    /**
     * A wrapper built directly from a model must have its private {@code errorConsumer}
     * field set to a {@link VoidErrorConsumer}. The field is read reflectively.
     */
    @Test
    public void testVoidErrorConsumerInitialized() throws NoSuchFieldException, IllegalAccessException {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(new MyAutoEncoderModel());

        Field consumerField = EasyPredictModelWrapper.class.getDeclaredField("errorConsumer");
        consumerField.setAccessible(true);
        Object consumer = consumerField.get(wrapper);

        Assert.assertNotNull(consumer);
        Assert.assertEquals(VoidErrorConsumer.class, consumer.getClass());
    }

    /**
     * A wrapper built from a {@code Config} that sets only the model must have its private
     * {@code errorConsumer} field set to a {@link VoidErrorConsumer}. The field is read
     * reflectively.
     */
    @Test
    public void testVoidErrorConsumerInitializedWithConfig() throws NoSuchFieldException, IllegalAccessException {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config();
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(config.setModel(new MyAutoEncoderModel()));

        Field consumerField = EasyPredictModelWrapper.class.getDeclaredField("errorConsumer");
        consumerField.setAccessible(true);
        Object consumer = consumerField.get(wrapper);

        Assert.assertNotNull(consumer);
        Assert.assertEquals(VoidErrorConsumer.class, consumer.getClass());
    }
}
