package hex.genmodel.algos.xgboost;

import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackendFactory;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:hex/genmodel/algos/xgboost/XGBoostJavaMojoModelTest.class */
public class XGBoostJavaMojoModelTest {
    @Test
    public void testObjFunction() {
        for (XGBoostMojoModel.ObjectiveType objectiveType : XGBoostMojoModel.ObjectiveType.values()) {
            Assert.assertNotNull(objectiveType.getId());
            Assert.assertFalse(objectiveType.getId().isEmpty());
            Assert.assertNotNull(XGBoostJavaMojoModel.getObjFunction(objectiveType.getId()));
        }
    }

    @Test
    public void testPredictContributionsSerialization() throws Exception {
        PredictContributions makeContributionsPredictor = MojoModel.load(MojoReaderBackendFactory.createReaderBackend(XGBoostJavaMojoModelTest.class.getResource("xgboost_java.zip"), MojoReaderBackendFactory.CachingStrategy.MEMORY)).makeContributionsPredictor();
        Assert.assertNotNull(makeContributionsPredictor);
        Assert.assertTrue(deserialize(serialize(makeContributionsPredictor)) instanceof PredictContributions);
    }

    @Test
    public void testLeafNodeAssignments() throws Exception {
        XGBoostJavaMojoModel load = MojoModel.load(MojoReaderBackendFactory.createReaderBackend(getClass().getResource("xgboost_java.zip"), MojoReaderBackendFactory.CachingStrategy.MEMORY));
        double[] dArr = {1.0d, 2.0d, 3.0d, 4.0d, 5.0d, 6.0d, 7.0d};
        SharedTreeMojoModel.LeafNodeAssignments leafNodeAssignments = load.getLeafNodeAssignments(dArr);
        Assert.assertNotNull(leafNodeAssignments._nodeIds);
        Assert.assertNotNull(leafNodeAssignments._paths);
        Assert.assertArrayEquals(load.getDecisionPath(dArr), leafNodeAssignments._paths);
        RowData rowData = new RowData();
        for (int i = 0; i < dArr.length; i++) {
            rowData.put(load._names[i], Double.valueOf(dArr[i]));
        }
        RegressionModelPrediction predict = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(load).setEnableLeafAssignment(true)).predict(rowData);
        Assert.assertNotNull(predict.leafNodeAssignmentIds);
        Assert.assertNotNull(predict.leafNodeAssignments);
        Assert.assertArrayEquals(leafNodeAssignments._nodeIds, predict.leafNodeAssignmentIds);
        Assert.assertArrayEquals(leafNodeAssignments._paths, predict.leafNodeAssignments);
    }

    private static byte[] serialize(Object obj) throws Exception {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
        Throwable th = null;
        try {
            try {
                objectOutputStream.writeObject(obj);
                if (objectOutputStream != null) {
                    if (0 != 0) {
                        try {
                            objectOutputStream.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        objectOutputStream.close();
                    }
                }
                return byteArrayOutputStream.toByteArray();
            } finally {
            }
        } catch (Throwable th3) {
            if (objectOutputStream != null) {
                if (th != null) {
                    try {
                        objectOutputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    objectOutputStream.close();
                }
            }
            throw th3;
        }
    }

    private static Object deserialize(byte[] bArr) throws Exception {
        ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(bArr);
        Throwable th = null;
        try {
            Object readObject = new ObjectInputStream(byteArrayInputStream).readObject();
            if (byteArrayInputStream != null) {
                if (0 != 0) {
                    try {
                        byteArrayInputStream.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                } else {
                    byteArrayInputStream.close();
                }
            }
            return readObject;
        } catch (Throwable th3) {
            if (byteArrayInputStream != null) {
                if (0 != 0) {
                    try {
                        byteArrayInputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    byteArrayInputStream.close();
                }
            }
            throw th3;
        }
    }
}
