package water.rapids.ast.prims.advmath;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.rapids.Rapids;
import water.util.ArrayUtils;

/* loaded from: input_file:water/rapids/ast/prims/advmath/StratifiedSplitTest.class */
/**
 * Tests the Rapids primitive {@code h2o.random_stratified_split}.
 *
 * <p>The split runs on a 12-row response column in which only rows 0 and 11 hold the value 1.
 * The column is tested both as a plain integer vec and as a vec built with the domain
 * {@code {"dog", "cat"}}. The split uses a fraction of 0.3333333 and seed 123. Both inputs are
 * expected to produce the same column-0 assignment.
 *
 * <p>Frames live in static fields so that {@link #teardown()} can release them after the class
 * has run.
 */
public class StratifiedSplitTest extends TestUtil {
    private static Frame f = null;
    private static Frame fr1 = null;
    private static Frame fanimal = null;
    private static Frame fr2 = null;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Deletes every frame the test created.
     *
     * <p>Each field may still be {@code null} if the test failed or was skipped before assigning
     * it. The deletes are guarded so that cleanup cannot throw an NPE that hides the original
     * failure.
     */
    @AfterClass
    public static void teardown() {
        deleteIfPresent(f);
        deleteIfPresent(fr1);
        deleteIfPresent(fanimal);
        deleteIfPresent(fr2);
    }

    @Test
    public void testStratifiedSampling() {
        f = publish(ArrayUtils.frame("response", vec(ari(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1))));
        fanimal = publish(ArrayUtils.frame("response",
                vec(ar("dog", "cat"), ari(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1))));

        fr1 = stratifiedSplit(f);
        assertExpectedSplit(fr1);

        fr2 = stratifiedSplit(fanimal);
        assertExpectedSplit(fr2);
    }

    /** Wraps {@code source} in a new Frame with a fresh key and puts it into the DKV. */
    private static Frame publish(Frame source) {
        Frame keyed = new Frame(source);
        keyed._key = Key.make();
        DKV.put(keyed);
        return keyed;
    }

    /** Runs the stratified split on column 0 of {@code input} (fraction 0.3333333, seed 123). */
    private static Frame stratifiedSplit(Frame input) {
        return Rapids.exec("(h2o.random_stratified_split (cols_py " + input._key + " 0) 0.3333333 123)").getFrame();
    }

    /**
     * Checks the seed-123 assignment expected for the 12-row input.
     *
     * <p>Row 0 gets 1, row 11 gets 0, and a third of the rows are 1 on average.
     * NOTE(review): presumably 1 marks the held-out split; confirm against the primitive.
     */
    private static void assertExpectedSplit(Frame split) {
        Assert.assertEquals(1L, split.vec(0).at8(0L));
        Assert.assertEquals(0L, split.vec(0).at8(11L));
        Assert.assertEquals(0.3333333333333333d, split.vec(0).mean(), 1.0E-5d);
    }

    /** Deletes {@code frame} when it was created; does nothing for {@code null}. */
    private static void deleteIfPresent(Frame frame) {
        if (frame != null) {
            frame.delete();
        }
    }
}
