package hex.genmodel;

import com.google.common.io.ByteStreams;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.tools.BuildPipeline;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests for {@link MojoPipelineBuilder} and the {@link BuildPipeline} command-line tool.
 *
 * <p>Both tests assemble a pipeline MOJO from a k-means MOJO and a GLM MOJO (the GLM is the
 * main model). Output column 0 of the k-means model is mapped onto the {@code CLUSTER} input.
 * Each test then checks that the resulting pipeline produces the expected predictions for two
 * fixed rows.
 */
public class MojoPipelineBuilderTest {

    /** Classpath directory that holds the MOJO fixtures used by these tests. */
    private static final String RESOURCE_DIR = "/hex/genmodel/algos/pipeline/";

    /** Absolute tolerance used when comparing predicted values. */
    private static final double DELTA = 1.0E-7d;

    @Rule
    public TemporaryFolder tmp = new TemporaryFolder();
    private MojoPipelineBuilder builder;
    private File targetMojoPipelineFile;
    private File kmeansMojoFile;
    private File glmMojoFile;

    /** Creates a fresh builder and copies the MOJO fixtures into the temporary folder. */
    @Before
    public void setup() throws IOException {
        this.builder = new MojoPipelineBuilder();
        this.targetMojoPipelineFile = this.tmp.newFile("mojo-pipeline.zip");
        this.kmeansMojoFile = copyMojoFileResource("kmeans_model.zip");
        this.glmMojoFile = copyMojoFileResource("glm_model.zip");
    }

    /** Builds the pipeline through the programmatic {@link MojoPipelineBuilder} API. */
    @Test
    public void testPipelineBuilder() throws IOException, PredictException {
        this.builder
                .addModel("clustering", this.kmeansMojoFile)
                .addMapping("CLUSTER", "clustering", 0)
                .addMainModel("regression", this.glmMojoFile)
                .buildPipeline(this.targetMojoPipelineFile);
        checkPipeline(this.targetMojoPipelineFile);
    }

    /**
     * Builds the same pipeline through the {@link BuildPipeline} CLI entry point.
     *
     * <p>NOTE(review): the mapping refers to the k-means model as {@code kmeans_model}. This
     * presumably relies on the tool deriving model names from the input file names
     * ({@code kmeans_model.zip}); confirm against {@link BuildPipeline}.
     */
    @Test
    public void testBuildPipelineTool() throws IOException, PredictException {
        BuildPipeline.main(new String[]{
                "--mapping", "CLUSTER=kmeans_model:0",
                "--input", this.kmeansMojoFile.getAbsolutePath(), this.glmMojoFile.getAbsolutePath(),
                "--output", this.targetMojoPipelineFile.getAbsolutePath()});
        checkPipeline(this.targetMojoPipelineFile);
    }

    /**
     * Loads the pipeline MOJO from {@code file} and asserts its predictions for two fixed rows.
     *
     * @param file the pipeline MOJO to verify
     */
    private static void checkPipeline(File file) throws IOException, PredictException {
        EasyPredictModelWrapper model = new EasyPredictModelWrapper(MojoModel.load(file.getAbsolutePath()));
        Assert.assertEquals(0.7812266d, model.predict(prow(71.0d, 1, 3.0d, 2.0d, 3.3d, 0.0d, 8.0d)).value, DELTA);
        Assert.assertEquals(0.5690164d, model.predict(prow(76.0d, 2, 2.0d, 1.0d, 51.2d, 20.0d, 7.0d)).value, DELTA);
    }

    /**
     * Builds an input row for the prediction tests.
     *
     * <p>{@code RACE} is stored as a String (the numeric code rendered as text). All other columns
     * are stored as Doubles.
     */
    private static RowData prow(double age, int race, double dpros, double dcaps, double psa,
                                double vol, double gleason) {
        RowData row = new RowData();
        row.put("AGE", Double.valueOf(age));
        row.put("RACE", String.valueOf(race));
        row.put("DPROS", Double.valueOf(dpros));
        row.put("DCAPS", Double.valueOf(dcaps));
        row.put("PSA", Double.valueOf(psa));
        row.put("VOL", Double.valueOf(vol));
        row.put("GLEASON", Double.valueOf(gleason));
        return row;
    }

    /**
     * Copies a MOJO fixture from the classpath into the temporary folder.
     *
     * @param fileName name of the resource under {@code /hex/genmodel/algos/pipeline/}; also used
     *     as the name of the copied file
     * @return the copied file inside the temporary folder
     * @throws IOException if the resource is not on the classpath or copying fails
     */
    private File copyMojoFileResource(String fileName) throws IOException {
        String resourcePath = RESOURCE_DIR + fileName;
        File target = this.tmp.newFile(fileName);
        try (InputStream in = getClass().getResourceAsStream(resourcePath)) {
            // getResourceAsStream returns null for a missing resource; fail with a clear message
            // instead of an NPE inside the copy.
            if (in == null) {
                throw new IOException("Test resource not found on classpath: " + resourcePath);
            }
            try (FileOutputStream out = new FileOutputStream(target)) {
                ByteStreams.copy(in, out);
            }
        }
        return target;
    }
}
