package hex.glm;

import hex.StringPair;
import hex.glm.GLMModel;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.util.ArrayUtils;

/**
 * Tests for GLM's {@code PlugValues} missing-values handling, where NAs in the training data are
 * replaced by values taken from a user-supplied one-row frame ({@code _plug_values}).
 *
 * <p>Most tests train one model on data containing NAs with plug values, and a reference model on
 * the same data with those NAs already replaced by hand, then require identical coefficients.
 */
public class GLMPlugValuesTest extends TestUtil {
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testNumeric() {
        Scope.enter();
        try {
            // Row 0 is missing "y", row 1 is missing "x".
            Frame trainWithNAs = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, Double.NaN}))
                    .withDataForCol(1, ard(new double[]{Double.NaN, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            // The same data with the NAs replaced by hand with the plug values below (x=4, y=0.5).
            Frame trainImputed = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(1, ard(new double[]{0.5, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            Frame plugValues = oneRowFrame(new String[]{"x", "y"}, new double[]{4.0, 0.5});

            GLMModel.GLMParameters plugParams = gaussianNoInterceptParams(trainWithNAs, "z");
            plugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            plugParams._plug_values = plugValues._key;
            GLMModel plugModel = trainAndTrack(plugParams);
            GLMModel referenceModel = trainAndTrack(gaussianNoInterceptParams(trainImputed, "z"));

            Assert.assertEquals(referenceModel.coefficients(), plugModel.coefficients());
            // Score the NA-containing frame and cross-check it with testJavaScoring.
            plugModel.testJavaScoring(trainWithNAs, Scope.track(new Frame[]{plugModel.score(trainWithNAs)}),
                    1.0E-8d, 1.0E-15d, 1.0d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testCategorical() {
        Scope.enter();
        try {
            Frame trainComplete = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar(new String[]{"a", "b"}))
                    .withDataForCol(1, ard(new double[]{1.0, 2.0}))
                    .build();
            // Copy of the frame with the categorical value in row 1 ("b") set to NA.
            Frame trainWithNA = trainComplete.deepCopy(Key.make().toString());
            trainWithNA.vec(0).setNA(1L);
            Scope.track(new Frame[]{trainWithNA});
            DKV.put(trainWithNA);
            // Plugging "b" back in must reproduce the complete frame.
            Frame plugValues = oneRowFrame(new String[]{"x"}, new double[0], "b");

            GLMModel.GLMParameters plugParams = gaussianNoInterceptParams(trainWithNA, "y");
            plugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            plugParams._plug_values = plugValues._key;
            GLMModel plugModel = trainAndTrack(plugParams);
            GLMModel referenceModel = trainAndTrack(gaussianNoInterceptParams(trainComplete, "y"));

            Assert.assertEquals(referenceModel.coefficients(), plugModel.coefficients());
            plugModel.testJavaScoring(trainWithNA, Scope.track(new Frame[]{plugModel.score(trainWithNA)}),
                    1.0E-8d, 1.0E-15d, 1.0d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testNumericInteraction() {
        Scope.enter();
        try {
            // Row 0 is missing "y".
            Frame trainWithNA = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(1, ard(new double[]{Double.NaN, 2.0}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0}))
                    .build();
            // Hand-imputed data (y=0.5) with the x*y interaction column precomputed.
            Frame trainImputed = new TestFrameBuilder()
                    .withColNames(new String[]{"x_y", "x", "y", "z"})
                    .withDataForCol(0, ard(new double[]{0.5, 8.0}))
                    .withDataForCol(1, ard(new double[]{1.0, 4.0}))
                    .withDataForCol(2, ard(new double[]{0.5, 2.0}))
                    .withDataForCol(3, ard(new double[]{2.0, 8.0}))
                    .build();
            Frame plugValues = oneRowFrame(new String[]{"x_y", "x", "y"}, new double[]{0.5, 4.0, 0.5});

            GLMModel.GLMParameters plugParams = gaussianNoInterceptParams(trainWithNA, "z");
            plugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            plugParams._plug_values = plugValues._key;
            plugParams._interaction_pairs = new StringPair[]{new StringPair("x", "y")};
            GLMModel plugModel = trainAndTrack(plugParams);
            GLMModel referenceModel = trainAndTrack(gaussianNoInterceptParams(trainImputed, "z"));

            // Compare as doubles: comparing an Integer 0 with a Double via equals() can never fail.
            Double interactionCoef = plugModel.coefficients().get("x_y");
            Assert.assertNotNull("model has no x_y interaction coefficient", interactionCoef);
            Assert.assertNotEquals("x_y interaction coefficient should be non-zero", 0.0, interactionCoef, 0.0);
            Assert.assertEquals(referenceModel.coefficients(), plugModel.coefficients());
        } finally {
            Scope.exit();
        }
    }

    /** Smoke test: training with plug values and a categorical x categorical interaction. */
    @Test
    public void testCatCatInteraction_smoke() {
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withColNames(new String[]{"n", "x", "y", "z"})
                    .withVecTypes(new byte[]{Vec.T_NUM, Vec.T_CAT, Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar(new long[]{0, 1, 0, 1}))
                    .withDataForCol(1, ar(new String[]{"a", "b", "a", "b"}))
                    .withDataForCol(2, ar(new String[]{"A", "B", "B", "A"}))
                    .withDataForCol(3, ard(new double[]{2.0, 8.0, 4.0, 1.0}))
                    .build();
            Frame plugValues = oneRowFrame(new String[]{"n", "x_y", "x", "y"}, new double[]{0.0}, "a_A", "a", "B");

            GLMModel.GLMParameters params = gaussianNoInterceptParams(train, "z");
            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            params._plug_values = plugValues._key;
            params._interaction_pairs = new StringPair[]{new StringPair("x", "y")};
            trainAndTrack(params);
        } finally {
            Scope.exit();
        }
    }

    /** Smoke test: training with plug values and a numeric x categorical interaction. */
    @Test
    public void testNumCatInteraction_smoke() {
        Scope.enter();
        try {
            // Row 1 is missing "x".
            Frame train = new TestFrameBuilder()
                    .withColNames(new String[]{"x", "y", "z"})
                    .withVecTypes(new byte[]{Vec.T_NUM, Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ard(new double[]{0.0, Double.NaN, 0.0, 1.0}))
                    .withDataForCol(1, ar(new String[]{"a", "b", "a", "b"}))
                    .withDataForCol(2, ard(new double[]{2.0, 8.0, 4.0, 1.0}))
                    .build();
            Frame plugValues = oneRowFrame(new String[]{"x", "x_y.a", "x_y.b", "y"}, new double[]{0.0, 1.0, 2.0}, "b");

            GLMModel.GLMParameters params = gaussianNoInterceptParams(train, "z");
            params._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            params._plug_values = plugValues._key;
            params._interaction_pairs = new StringPair[]{new StringPair("x", "y")};
            trainAndTrack(params);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testPlugValues_zeros() {
        Scope.enter();
        try {
            Frame cars = parse_test_file("smalldata/junit/cars.csv");
            Scope.track(new Frame[]{cars});
            cars.remove("name");
            DKV.put(cars);
            Assert.assertTrue(cars.vec("economy (mpg)").naCnt() > 0);

            GLMModel.GLMParameters meanParams = new GLMModel.GLMParameters(
                    GLMModel.GLMParameters.Family.poisson, GLMModel.GLMParameters.Family.poisson.defaultLink,
                    new double[]{0.0}, new double[]{0.0}, 0.0, 0.0);
            meanParams._response_column = "power (hp)";
            meanParams._train = cars._key;
            meanParams._lambda = new double[]{0.0};
            meanParams._alpha = new double[]{0.0};
            meanParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
            meanParams._seed = 42L;
            // Both variants are cloned from the mean-imputation settings before any model is trained.
            GLMModel.GLMParameters meanPlugParams = meanParams.clone();
            GLMModel.GLMParameters zeroPlugParams = meanParams.clone();
            GLMModel meanModel = trainAndTrack(meanParams);

            // Predictor columns only (shallow clone without the response).
            Frame predictors = cars.clone();
            predictors.remove(meanParams._response_column);

            // Plugging in the column means must reproduce mean imputation exactly.
            meanPlugParams._plug_values = oneRowFrame(predictors.names(), predictors.means())._key;
            meanPlugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            GLMModel meanPlugModel = trainAndTrack(meanPlugParams);
            Assert.assertArrayEquals(meanModel.beta(), meanPlugModel.beta(), 0.0d);
            Assert.assertArrayEquals(meanModel.dinfo()._numNAFill, meanPlugModel.dinfo()._numNAFill, 0.0d);

            // Plugging in zeros must be used as the NA fill and change the affected coefficient.
            zeroPlugParams._plug_values = oneRowFrame(predictors.names(), new double[predictors.numCols()])._key;
            zeroPlugParams._missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
            GLMModel zeroPlugModel = trainAndTrack(zeroPlugParams);
            Assert.assertArrayEquals(new double[predictors.numCols()], zeroPlugModel.dinfo()._numNAFill, 0.0d);
            Assert.assertNotEquals(meanModel.coefficients().get("economy (mpg)"),
                    zeroPlugModel.coefficients().get("economy (mpg)"));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Builds the gaussian GLM parameters shared by most tests: no intercept, no standardization,
     * constant columns kept, fixed seed 42.
     *
     * @param train    training frame
     * @param response name of the response column
     * @return a fresh parameter object
     */
    private static GLMModel.GLMParameters gaussianNoInterceptParams(Frame train, String response) {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters();
        params._response_column = response;
        params._family = GLMModel.GLMParameters.Family.gaussian;
        params._standardize = false;
        params._train = train._key;
        params._ignore_const_cols = false;
        params._intercept = false;
        params._seed = 42L;
        return params;
    }

    /**
     * Trains a GLM with the given parameters and tracks the model in the current {@link Scope}.
     *
     * @param params model parameters
     * @return the trained model
     */
    private static GLMModel trainAndTrack(GLMModel.GLMParameters params) {
        GLMModel model = new GLM(params).trainModel().get();
        Scope.track_generic(model);
        return model;
    }

    /**
     * Builds a single-row frame, e.g. to be used as {@code _plug_values}.
     *
     * @param names     column names: the numeric columns first, followed by the categorical ones
     * @param numValues values of the leading numeric columns
     * @param catValues levels of the trailing categorical columns
     * @return the built frame
     * @throws IllegalArgumentException if the number of names differs from the number of values
     */
    private static Frame oneRowFrame(String[] names, double[] numValues, String... catValues) {
        int numCount = numValues.length;
        int catCount = catValues.length;
        if (names.length != numCount + catCount) {
            throw new IllegalArgumentException("Expected " + (numCount + catCount) + " column names but got "
                    + names.length + ": " + Arrays.toString(names));
        }
        byte[] types = new byte[numCount + catCount];
        Arrays.fill(types, 0, numCount, Vec.T_NUM);
        Arrays.fill(types, numCount, types.length, Vec.T_CAT);
        TestFrameBuilder builder = new TestFrameBuilder().withColNames(names).withVecTypes(types);
        for (int i = 0; i < numCount; i++) {
            builder.withDataForCol(i, new double[]{numValues[i]});
        }
        for (int j = 0; j < catCount; j++) {
            builder.withDataForCol(numCount + j, new String[]{catValues[j]});
        }
        return builder.build();
    }
}
