package hex.tree.gbm;

import hex.Model;
import hex.tree.gbm.GBMModel;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.util.Log;

/**
 * Trains tiny single-tree GBM models on hand-built frames and checks their
 * predictions, once for every value of
 * {@link Model.Parameters.CategoricalEncodingScheme}.
 *
 * <p>The class runs under the {@link Parameterized} runner: {@link #data()}
 * supplies the encoding schemes and each one is injected into
 * {@link #encoding} before the test methods execute.
 */
@RunWith(Parameterized.class)
public class GBMEncodingTest extends TestUtil {

    /** Encoding scheme under test; must stay public so the runner can inject it. */
    @Parameterized.Parameter
    public Model.Parameters.CategoricalEncodingScheme encoding;

    /** Waits until a cloud of at least one node is available before any test runs. */
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Supplies the test parameters.
     *
     * @return every declared categorical encoding scheme, one per test run
     */
    @Parameterized.Parameters
    public static Iterable<?> data() {
        return Arrays.asList(Model.Parameters.CategoricalEncodingScheme.values());
    }

    /**
     * Trains on a frame whose single categorical predictor determines the
     * response, then verifies that scoring the training frame reproduces the
     * response column and that a one-row frame holding the known level "A" is
     * predicted as "V".
     */
    @Test
    public void testGBM_BasicCategoricalEncoding() {
        assumeEncodingIsExercised();
        Log.info("Using encoding " + encoding);
        try {
            Scope.enter();
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames("ColA", "Response")
                    .withVecTypes(Vec.T_CAT, Vec.T_CAT)
                    .withDataForCol(0, ar("B", "B", "A", "A", "A"))
                    .withDataForCol(1, ar("C", "C", "V", "V", "V"))
                    .build();

            GBMModel model = new GBM(newParameters(train, 1)).trainModel().get();
            Scope.track_generic(model);

            Frame trainPredictions = model.score(train);
            Scope.track(trainPredictions);
            assertStringVecEquals(train.vec("Response"), trainPredictions.vec(0));

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames("ColA")
                    .withVecTypes(Vec.T_CAT)
                    .withDataForCol(0, ar("A"))
                    .build();
            Frame testPredictions = model.score(test);
            Scope.track(testPredictions);
            Assert.assertEquals("V", testPredictions.vec(0).stringAt(0L));
        } finally {
            // Runs on success and on failure alike, so the scope is always closed.
            Scope.exit();
        }
    }

    /**
     * Scores a frame containing levels ("D", "E") that were absent from the
     * training data. Only the first row, which holds the known level "A", is
     * asserted; the rows with unseen levels merely have to be scored without
     * an exception.
     */
    @Test
    public void testGBM_CategoricalEncodingWithUnseenCategories() {
        assumeEncodingIsExercised();
        Log.info("Using encoding " + encoding);
        try {
            Scope.enter();
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames("ColA", "Response")
                    .withVecTypes(Vec.T_CAT, Vec.T_CAT)
                    .withDataForCol(0, ar("B", "B", "A", "A", "A", "B", "A"))
                    .withDataForCol(1, ar("C", "C", "V", "V", "V", "C", "V"))
                    .build();

            GBMModel model = new GBM(newParameters(train, 3)).trainModel().get();
            Scope.track_generic(model);
            Scope.track(model.score(train));

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames("ColA")
                    .withVecTypes(Vec.T_CAT)
                    .withDataForCol(0, ar("A", "D", "E"))
                    .build();
            Frame testPredictions = model.score(test);
            Scope.track(testPredictions);
            Assert.assertEquals("V", testPredictions.vec(0).stringAt(0L));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Scores frames that carry only a subset of the training predictors:
     * first the categorical column plus an unrelated column, then the numeric
     * column alone. Finally checks that a frame sharing no column with the
     * training data is rejected with an {@link IllegalArgumentException}
     * whose message mentions "no columns in common".
     */
    @Test
    public void testGBM_CategoricalEncodingWithPredictionsOnFeaturesSubset() {
        assumeEncodingIsExercised();
        Log.info("Using encoding " + encoding);
        try {
            Scope.enter();
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames("ColA", "ColB", "Response")
                    .withVecTypes(Vec.T_CAT, Vec.T_NUM, Vec.T_CAT)
                    .withDataForCol(0, ar("B", "B", "A", "A", "A", "B", "A"))
                    .withDataForCol(1, ar(2L, 2L, 1L, 1L, 1L, 2L, 1L))
                    .withDataForCol(2, ar("C", "C", "V", "V", "V", "C", "V"))
                    .build();

            GBMModel model = new GBM(newParameters(train, 3)).trainModel().get();
            Scope.track_generic(model);
            Scope.track(model.score(train));

            // Categorical predictor present, numeric predictor replaced by an unknown column.
            Frame categoricalOnly = new TestFrameBuilder()
                    .withName("testEncodingCat")
                    .withColNames("ColA", "ColZ")
                    .withVecTypes(Vec.T_CAT, Vec.T_NUM)
                    .withDataForCol(0, ar("A"))
                    .withDataForCol(1, ard(0.0d))
                    .build();
            Frame categoricalPredictions = model.score(categoricalOnly);
            Scope.track(categoricalPredictions);
            Assert.assertEquals("V", categoricalPredictions.vec(0).stringAt(0L));

            // Numeric predictor only.
            Frame numericOnly = new TestFrameBuilder()
                    .withName("testEncodingNum")
                    .withColNames("ColB")
                    .withVecTypes(Vec.T_NUM)
                    .withDataForCol(0, ar(1L))
                    .build();
            Frame numericPredictions = model.score(numericOnly);
            Scope.track(numericPredictions);
            Assert.assertEquals("V", numericPredictions.vec(0).stringAt(0L));

            // No predictor in common with the training frame: scoring must fail.
            Frame nothingInCommon = new TestFrameBuilder()
                    .withName("testEncodingNoCommon")
                    .withColNames("ColZ")
                    .withVecTypes(Vec.T_NUM)
                    .withDataForCol(0, ar(1L))
                    .build();
            try {
                Scope.track(model.score(nothingInCommon));
                Assert.fail("Should have thrown IllegalArgumentException");
            } catch (IllegalArgumentException e) {
                // Guard against a null message so the test fails with the
                // assertion text rather than a NullPointerException.
                String message = e.getMessage();
                Assert.assertTrue(
                        "Expected exception due to no column in common with training data, but got: " + message,
                        message != null && message.contains("no columns in common"));
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Skips the current test for {@code OneHotInternal}.
     *
     * <p>The tests have always bypassed this scheme; using an assumption
     * makes the runner report those runs as skipped rather than as passed.
     * NOTE(review): presumably tree models do not support this encoding -
     * confirm and record the reason here.
     */
    private void assumeEncodingIsExercised() {
        Assume.assumeTrue(encoding != Model.Parameters.CategoricalEncodingScheme.OneHotInternal);
    }

    /**
     * Builds the GBM parameters shared by all tests: one tree, seed 1,
     * learning rate 1.0, minimum of one row per leaf, response column
     * "Response" and the encoding scheme currently under test. For
     * {@code EnumLimited} the number of categorical levels is capped at 2.
     *
     * @param train    training frame; its key is stored in the parameters
     * @param maxDepth maximum depth of the single tree
     * @return a freshly populated parameter object
     */
    private GBMModel.GBMParameters newParameters(Frame train, int maxDepth) {
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._seed = 1L;
        params._train = train._key;
        params._response_column = "Response";
        params._ntrees = 1;
        params._max_depth = maxDepth;
        params._learn_rate = 1.0d;
        params._min_rows = 1.0d;
        params._categorical_encoding = encoding;
        if (encoding == Model.Parameters.CategoricalEncodingScheme.EnumLimited) {
            params._max_categorical_levels = 2;
        }
        return params;
    }
}
