package hex.tree.gbm;

import hex.ModelBuilder;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBMModel;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;

@RunWith(Parameterized.class)
/**
 * Checks that shared-tree algorithms (GBM and DRF) can build a tree whose root split separates
 * missing values from all other values of a single predictor column.
 */
public class SharedTreeTest extends TestUtil {

    /** Algorithm parameters for the current run; supplied by {@link #data()}. */
    @Parameterized.Parameter
    public SharedTreeModel.SharedTreeParameters parms;

    /** Waits until a one-node cloud is available before any test runs. */
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Supplies the parameter sets every test runs against: a GBM with {@code _learn_rate = 1.0} and
     * a DRF with {@code _sample_rate = 1.0}.
     *
     * <p>NOTE(review): presumably these settings let a single tree fit the tiny training frame
     * exactly (full step size for GBM, all rows sampled for DRF) — confirm.
     *
     * <p>Note that JUnit's Parameterized runner reuses these parameter objects across test methods,
     * so each test must (re)set every field it depends on; {@code checkNAPredictor} does so.
     *
     * @return one GBM parameter set followed by one DRF parameter set
     */
    @Parameterized.Parameters(name = "{index}: shared_tree({0})")
    public static Iterable<SharedTreeModel.SharedTreeParameters> data() {
        GBMModel.GBMParameters gbmParms = new GBMModel.GBMParameters();
        gbmParms._learn_rate = 1.0d;
        DRFModel.DRFParameters drfParms = new DRFModel.DRFParameters();
        drfParms._sample_rate = 1.0d;
        return Arrays.<SharedTreeModel.SharedTreeParameters>asList(gbmParms, drfParms);
    }

    /** Categorical predictor: NA rows map to response "A", level "V" rows map to "B". */
    @Test
    public void testNAPredictor_cat() {
        checkNAPredictor(new TestFrameBuilder()
                .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                .withDataForCol(0, ar(new String[]{null, "V", null, "V", null, "V"})));
    }

    /** Numeric predictor: NaN rows map to response "A", value 1.0 rows map to "B". */
    @Test
    public void testNAPredictor_num() {
        checkNAPredictor(new TestFrameBuilder()
                .withVecTypes(new byte[]{Vec.T_NUM, Vec.T_CAT})
                .withDataForCol(0, ard(new double[]{Double.NaN, 1.0d, Double.NaN, 1.0d, Double.NaN, 1.0d})));
    }

    /**
     * Completes the frame with column names "F" and "Response". The response alternates "A"/"B"
     * over six rows. The method then trains one tree on that frame, using it as both training and
     * validation data, and asserts that:
     * <ul>
     *   <li>the model's classification error is 0,</li>
     *   <li>scoring on column "F" alone reproduces "Response",</li>
     *   <li>tree (0, 0) has exactly 3 nodes, and</li>
     *   <li>its root node is an NA-vs-rest split.</li>
     * </ul>
     *
     * @param testFrameBuilder builder whose column 0 ("F") data and both column types are already
     *     set; this method adds the column names and the response data
     */
    private void checkNAPredictor(TestFrameBuilder testFrameBuilder) {
        Scope.enter();
        // try/finally guarantees exactly one Scope.exit(), whether the body succeeds or throws.
        try {
            Frame train = testFrameBuilder
                    .withColNames(new String[]{"F", "Response"})
                    .withDataForCol(1, ar(new String[]{"A", "B", "A", "B", "A", "B"}))
                    .build();

            parms._train = train._key;
            parms._valid = train._key;
            parms._response_column = "Response";
            parms._ntrees = 1;
            parms._ignore_const_cols = true;
            parms._min_rows = 1.0d;

            // Explicit cast: harmless if get() is already typed, required if it is erased.
            SharedTreeModel model = (SharedTreeModel) ModelBuilder.make(parms).trainModel().get();
            Scope.track_generic(model);

            Assert.assertEquals(0.0d, model.classification_error(), 0.0d);

            Frame features = Scope.track(train.subframe(new String[]{"F"}));
            Frame predictions = Scope.track(model.score(features));
            assertCatVecEquals(train.vec("Response"), predictions.vec("predict"));

            SharedTreeSubgraph tree = model.getSharedTreeSubgraph(0, 0);
            Assert.assertEquals(3L, tree.nodesArray.size());
            Assert.assertTrue(tree.rootNode.isNaVsRest());
        } finally {
            Scope.exit();
        }
    }
}
