package hex;

import hex.Model;
import hex.ModelMetrics;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.parser.BufferedString;

/* loaded from: input_file:hex/ModelBuilderTest.class */
public class ModelBuilderTest extends TestUtil {

    /* loaded from: input_file:hex/ModelBuilderTest$BulkRunner.class */
    public static class BulkRunner extends H2O.H2OCountedCompleter<BulkRunner> {
        private Job _j;

        private BulkRunner(Job job) {
            this._j = job;
        }

        public void compute2() {
            ModelBuilder.bulkBuildModels("dummy-group", this._j, new ModelBuilder[]{new DummyModelBuilder(new DummyModelParameters("Dummy 1", Key.make(this._j._key + "-dummny-1"))), new DummyModelBuilder(new DummyModelParameters("Dummy 2", Key.make(this._j._key + "-dummny-2")))}, 1, 1);
            Assert.assertEquals(0.2d, this._j.progress(), 0.001d);
            tryComplete();
        }
    }

    /* loaded from: input_file:hex/ModelBuilderTest$DummyModel.class */
    public static class DummyModel extends Model<DummyModel, DummyModelParameters, DummyModelOutput> {
        public DummyModel(Key<DummyModel> key, DummyModelParameters dummyModelParameters, DummyModelOutput dummyModelOutput) {
            super(key, dummyModelParameters, dummyModelOutput);
        }

        public ModelMetrics.MetricBuilder makeMetricBuilder(String[] strArr) {
            return null;
        }

        protected double[] score0(double[] dArr, double[] dArr2) {
            return dArr2;
        }
    }

    /* loaded from: input_file:hex/ModelBuilderTest$DummyModelBuilder.class */
    public static class DummyModelBuilder extends ModelBuilder<DummyModel, DummyModelParameters, DummyModelOutput> {
        public DummyModelBuilder(DummyModelParameters dummyModelParameters) {
            super(dummyModelParameters);
            init(false);
        }

        protected ModelBuilder<DummyModel, DummyModelParameters, DummyModelOutput>.Driver trainModelImpl() {
            return new ModelBuilder<DummyModel, DummyModelParameters, DummyModelOutput>.Driver() { // from class: hex.ModelBuilderTest.DummyModelBuilder.1
                public void computeImpl() {
                    DKV.put(((DummyModelParameters) DummyModelBuilder.this._parms)._trgt, new BufferedString("Computed " + ((DummyModelParameters) DummyModelBuilder.this._parms)._msg));
                }
            };
        }

        public ModelCategory[] can_build() {
            return new ModelCategory[0];
        }

        public boolean isSupervised() {
            return false;
        }
    }

    /* loaded from: input_file:hex/ModelBuilderTest$DummyModelOutput.class */
    public static class DummyModelOutput extends Model.Output {
    }

    /* loaded from: input_file:hex/ModelBuilderTest$DummyModelParameters.class */
    public static class DummyModelParameters extends Model.Parameters {
        private String _msg;
        private Key _trgt;

        public DummyModelParameters(String str, Key key) {
            this._msg = str;
            this._trgt = key;
        }

        public String fullName() {
            return "dummy";
        }

        public String algoName() {
            return "dummy";
        }

        public String javaName() {
            return DummyModelBuilder.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }
    }

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testRebalancePubDev5400() {
        try {
            Scope.enter();
            int i = H2O.NUMCPUS;
            int i2 = i * 1000;
            double[] dArr = new double[i2];
            String[] strArr = new String[i2];
            for (int i3 = 0; i3 < dArr.length; i3++) {
                dArr[i3] = i3 % 7;
                strArr[i3] = i3 % 3 == 0 ? "A" : "B";
            }
            long[] jArr = new long[i];
            jArr[i - 1] = dArr.length;
            Frame track = Scope.track(new Frame[]{new TestFrameBuilder().withName("testFrame").withColNames("ColA", "Response").withVecTypes(3, 4).withDataForCol(0, dArr).withDataForCol(1, strArr).withChunkLayout(jArr).build()});
            Assert.assertEquals(i, track.anyVec().nChunks());
            Assert.assertEquals(dArr.length, track.numRows());
            DummyModelParameters dummyModelParameters = new DummyModelParameters("Rebalance Test", Key.make("rebalance-test"));
            dummyModelParameters._train = track._key;
            DummyModelBuilder dummyModelBuilder = new DummyModelBuilder(dummyModelParameters);
            Assert.assertEquals(i, dummyModelBuilder.desiredChunks(track, true));
            dummyModelBuilder.init(true);
            long[] espc = dummyModelBuilder.train().anyVec().espc();
            Assert.assertEquals(i + 1, espc.length);
            Assert.assertEquals(i2, espc[i]);
            for (int i4 = 0; i4 < espc.length; i4++) {
                Assert.assertEquals(i4 * 1000, espc[i4]);
            }
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    @Test
    public void testRebalanceMulti() {
        Assume.assumeTrue(H2O.getCloudSize() > 1);
        try {
            Scope.enter();
            double[] dArr = new double[1000000];
            String[] strArr = new String[dArr.length];
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] = i % 7;
                strArr[i] = i % 3 == 0 ? "A" : "B";
            }
            Frame track = Scope.track(new Frame[]{new TestFrameBuilder().withName("testFrame").withColNames("ColA", "Response").withVecTypes(3, 4).withDataForCol(0, dArr).withDataForCol(1, strArr).withChunkLayout(dArr.length).build()});
            Assert.assertEquals(1L, track.anyVec().nChunks());
            DummyModelParameters dummyModelParameters = new DummyModelParameters("Rebalance Test", Key.make("rebalance-test"));
            dummyModelParameters._train = track._key;
            DummyModelBuilder dummyModelBuilder = new DummyModelBuilder(dummyModelParameters) { // from class: hex.ModelBuilderTest.1
                protected String getSysProperty(String str, String str2) {
                    if (str.equals("rebalance.ratio.multi")) {
                        return "0.5";
                    }
                    if (str.equals("rebalance.enableMulti")) {
                        return "true";
                    }
                    if (str.startsWith("sys.ai.h2o.rebalance")) {
                        throw new IllegalStateException("Unexpected property: " + str);
                    }
                    return super.getSysProperty(str, str2);
                }
            };
            int desiredChunks = dummyModelBuilder.desiredChunks(track, false);
            Assert.assertTrue(desiredChunks > 4 * H2O.NUMCPUS);
            dummyModelBuilder.init(true);
            Assert.assertEquals(desiredChunks, dummyModelBuilder.train().anyVec().nonEmptyChunks());
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    @Test
    public void testMakeUnknownModel() {
        try {
            ModelBuilder.make("invalid", (Job) null, (Key) null);
            Assert.fail();
        } catch (IllegalStateException e) {
            Assert.assertEquals("Algorithm 'invalid' is not registered. Available algos: []", e.getMessage());
        }
    }

    @Test
    public void bulkBuildModels() throws Exception {
        Job job = new Job((Key) null, (String) null, "BulkBuilding");
        Key make = Key.make(job._key + "-dummny-1");
        Key make2 = Key.make(job._key + "-dummny-2");
        try {
            job.start(new BulkRunner(job), 10L).get();
            Assert.assertEquals("Computed Dummy 1", DKV.getGet(make).toString());
            Assert.assertEquals("Computed Dummy 2", DKV.getGet(make2).toString());
            DKV.remove(make);
            DKV.remove(make2);
        } catch (Throwable th) {
            DKV.remove(make);
            DKV.remove(make2);
            throw th;
        }
    }
}
