package hex.tree;

import hex.Model;
import hex.genmodel.tools.PrintMojo;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import hex.tree.isofor.IsolationForest;
import hex.tree.isofor.IsolationForestModel;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.Permission;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Tests for the {@link PrintMojo} command-line tool: trains small tree models, exports them as MOJOs and
 * checks the DOT, JSON and PNG renderings written by {@code PrintMojo.main}.
 *
 * <p>{@code PrintMojo.main} finishes by calling {@code System.exit()}. A security manager installed for each
 * test converts that call into a {@link PreventedExitException} so the test JVM keeps running.
 */
public class PrintMojoTreeTest {

    /** Seed shared by the categorical models below (0xFEED == 65261). */
    private static final long SEED = 0xFEEDL;

    /** Captures the text between the quotes of every {@code label="..."} attribute in DOT output. */
    private static final Pattern LABEL_PATTERN = Pattern.compile("label=\"(.*?)\"");

    /** Matches a label that contains a comparison operator ({@code <}, {@code >} or {@code =}). */
    private static final Pattern COMPARISON_PATTERN = Pattern.compile(".*[<>=].*");

    /** Matches the {@code "h2o_version": "x.y.z"} JSON entry so fixtures do not depend on the build version. */
    private static final Pattern H2O_VERSION_PATTERN = Pattern.compile("\"h2o_version\": \"[\\d\\.]+\"");

    /** DOT edges expected when categorical splits are printed with {@code --levels 1}. */
    private static final String EXPECTED_LIMITED_LEVELS_EDGES =
            "\"SG_0_Node_0\" -> \"SG_0_Node_1\" [fontsize=14, label=\"[NA]\n2 levels\n\"]\n"
                    + "\"SG_0_Node_0\" -> \"SG_0_Node_4\" [fontsize=14, label=\"Iris-virginica\n\"]\n"
                    + "\"SG_0_Node_1\" -> \"SG_0_Node_5\" [fontsize=14, label=\"Iris-versicolor\n\"]\n"
                    + "\"SG_0_Node_1\" -> \"SG_0_Node_6\" [fontsize=14, label=\"[NA]\nIris-setosa\n\"]";

    @Rule
    public TemporaryFolder folder = new TemporaryFolder();

    /** Security manager active before the test, restored in {@link #tearDown()}. */
    private SecurityManager originalSecurityManager;

    /** Security manager that allows every permission but turns {@code System.exit()} into an exception. */
    private static class PreventExitSecurityManager extends SecurityManager {

        @Override
        public void checkPermission(Permission permission) {
            // Deliberately permissive: only exit needs to be intercepted.
        }

        @Override
        public void checkPermission(Permission permission, Object context) {
            // Deliberately permissive: only exit needs to be intercepted.
        }

        @Override
        public void checkExit(int status) {
            throw new PreventedExitException(status);
        }
    }

    /** Thrown in place of terminating the JVM; carries the status that was passed to {@code System.exit()}. */
    protected static class PreventedExitException extends SecurityException {
        public final int status;

        public PreventedExitException(int status) {
            super("System.exit(" + status + ") prevented by test security manager");
            this.status = status;
        }
    }

    @Before
    public void setUp() throws Exception {
        TestUtil.stall_till_cloudsize(1);
        this.originalSecurityManager = System.getSecurityManager();
        // NOTE(review): System.setSecurityManager is deprecated for removal (JEP 411); newer JDKs reject it
        // unless started with -Djava.security.manager=allow -- confirm for the JDK this suite runs on.
        System.setSecurityManager(new PreventExitSecurityManager());
    }

    @After
    public void tearDown() throws Exception {
        System.setSecurityManager(this.originalSecurityManager);
    }

    @Test
    public void testMojoCategoricalPrint() throws IOException {
        Scope.enter();
        try {
            String dot = printMojo(trainIrisIsolationForest());
            System.out.println(dot);
            Assert.assertFalse(dot.isEmpty());

            // By default categorical splits are printed as level names, never as comparisons.
            int labelCount = 0;
            Matcher labels = LABEL_PATTERN.matcher(dot);
            while (labels.find()) {
                labelCount++;
                String label = labels.group(1);
                Assert.assertFalse("Unexpected comparison in label: " + label,
                        COMPARISON_PATTERN.matcher(label).matches());
            }
            // Guard against the loop above passing vacuously.
            Assert.assertTrue("Expected at least one label in the DOT output", labelCount > 0);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalPrint_limitedLevels() throws IOException {
        Scope.enter();
        try {
            String dot = printMojo(trainIrisIsolationForest(), "--levels", "1");
            Assert.assertTrue("DOT output is missing the expected limited-levels edges:\n" + dot,
                    dot.contains(EXPECTED_LIMITED_LEVELS_EDGES));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalPrint_internalRepresentationOutput() throws IOException {
        Scope.enter();
        try {
            String dot = printMojo(trainIrisIsolationForest(), "--internal");
            Assert.assertFalse(dot.isEmpty());
            System.out.println(dot);

            // With --internal, at least one split should be rendered as a comparison.
            int comparisonLabels = 0;
            Matcher labels = LABEL_PATTERN.matcher(dot);
            while (labels.find()) {
                if (COMPARISON_PATTERN.matcher(labels.group(1)).matches()) {
                    comparisonLabels++;
                }
            }
            Assert.assertTrue("Expected at least one label with a comparison", comparisonLabels > 0);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalJson() throws IOException {
        Scope.enter();
        try {
            assertMojoJSONEqualsFixture(trainAirlinesIsolationForest(), "categorical.json");
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalOneHotJson() throws IOException {
        Scope.enter();
        try {
            assertMojoJSONEqualsFixture(trainAirlinesOneHotGbm(), "categoricalOneHot.json");
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoGBMJson() throws IOException {
        Scope.enter();
        try {
            assertMojoJSONEqualsFixture(trainProstateGbm(), "gbmProstate.json");
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalPng() throws IOException {
        Scope.enter();
        try {
            assertMojoPngGenerated(trainAirlinesIsolationForest(), new String[]{"exampleh2o.png"});
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoCategoricalOneHotPng() throws IOException {
        Scope.enter();
        try {
            assertMojoPngGenerated(trainAirlinesOneHotGbm(), new String[]{"Tree1_ClassNO.png", "Tree0_ClassNO.png"});
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testMojoGBMPng() throws IOException {
        Scope.enter();
        try {
            assertMojoPngGenerated(trainProstateGbm(), new String[]{"Tree1.png", "Tree0.png"});
        } finally {
            Scope.exit();
        }
    }

    // ----------------------------------------------------------------------------------------------------
    // Model factories. Each must be called inside an active Scope so the training frame is tracked.
    // ----------------------------------------------------------------------------------------------------

    /** Trains a one-tree isolation forest on iris with columns C1-C4 ignored (leaving the species column). */
    private static IsolationForestModel trainIrisIsolationForest() throws IOException {
        Frame train = Scope.track(TestUtil.parse_test_file("smalldata/iris/iris.csv"));
        IsolationForestModel.IsolationForestParameters params = new IsolationForestModel.IsolationForestParameters();
        params._train = train._key;
        params._ignored_columns = new String[]{"C1", "C2", "C3", "C4"};
        params._seed = SEED;
        params._ntrees = 1;
        return new IsolationForest(params).trainModel().get();
    }

    /** Trains a one-tree, depth-3 isolation forest on the airlines data. */
    private static IsolationForestModel trainAirlinesIsolationForest() throws IOException {
        Frame train = Scope.track(TestUtil.parse_test_file("smalldata/testng/airlines.csv"));
        IsolationForestModel.IsolationForestParameters params = new IsolationForestModel.IsolationForestParameters();
        params._train = train._key;
        params._seed = SEED;
        // NOTE(review): IsDepDelayed is set as the response and also ignored; kept as-is because changing the
        // parameters may alter the trained tree and break the categorical.json fixture comparison.
        params._response_column = "IsDepDelayed";
        params._ntrees = 1;
        params._max_depth = 3;
        params._ignored_columns = new String[]{"Origin", "Dest", "IsDepDelayed"};
        return new IsolationForest(params).trainModel().get();
    }

    /** Trains a two-tree, depth-3 GBM on airlines using explicit one-hot categorical encoding. */
    private static GBMModel trainAirlinesOneHotGbm() throws IOException {
        Frame train = Scope.track(TestUtil.parse_test_file("smalldata/testng/airlines.csv"));
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._seed = SEED;
        params._response_column = "IsDepDelayed";
        params._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotExplicit;
        params._ntrees = 2;
        params._max_depth = 3;
        params._ignored_columns = new String[]{"Origin", "Dest"};
        return new GBM(params).trainModel().get();
    }

    /** Trains a two-tree, depth-3 GBM on prostate predicting CAPSULE. */
    private static GBMModel trainProstateGbm() throws IOException {
        Frame train = Scope.track(TestUtil.parse_test_file("smalldata/extdata/prostate.csv"));
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._response_column = "CAPSULE";
        params._ignored_columns = new String[]{"ID"};
        params._seed = 1L;
        params._ntrees = 2;
        params._max_depth = 3;
        return new GBM(params).trainModel().get();
    }

    // ----------------------------------------------------------------------------------------------------
    // PrintMojo helpers and assertions.
    // ----------------------------------------------------------------------------------------------------

    /** Exports {@code model} as a MOJO into a fresh temporary file and returns that file. */
    private File exportMojo(Model model) throws IOException {
        File mojo = this.folder.newFile();
        model.exportMojo(mojo.getAbsolutePath(), true);
        return mojo;
    }

    /**
     * Runs {@code PrintMojo.main} with {@code --input mojo --output output} followed by {@code options},
     * and fails unless it ends by calling {@code System.exit()}.
     */
    private static void runPrintMojo(File mojo, File output, String... options) throws IOException {
        List<String> args = new ArrayList<>(Arrays.asList(
                "--input", mojo.getAbsolutePath(), "--output", output.getAbsolutePath()));
        args.addAll(Arrays.asList(options));
        try {
            PrintMojo.main(args.toArray(new String[0]));
            Assert.fail("Expected PrintMojo to call System.exit()");
        } catch (PreventedExitException ignored) {
            // Expected: PrintMojo always terminates through System.exit().
        }
    }

    /**
     * Exports {@code model}, prints it with PrintMojo using {@code options} and returns the output file's text.
     * NOTE(review): reads with the platform default charset (as the original did); switch to an explicit
     * charset once PrintMojo's output encoding is confirmed.
     */
    private String printMojo(Model model, String... options) throws IOException {
        File mojo = exportMojo(model);
        File output = this.folder.newFile();
        runPrintMojo(mojo, output, options);
        return FileUtils.readFileToString(output);
    }

    /** Reads a classpath fixture (relative to this class) as UTF-8, failing clearly if it is missing. */
    private String readFixture(String resourceName) throws IOException {
        try (InputStream in = getClass().getResourceAsStream(resourceName)) {
            Assert.assertNotNull("Missing test fixture: " + resourceName, in);
            return IOUtils.toString(in, StandardCharsets.UTF_8);
        }
    }

    /** Asserts that the JSON rendering of {@code model} equals the fixture {@code str}, ignoring the H2O version. */
    private void assertMojoJSONEqualsFixture(Model model, String str) throws IOException {
        String json = printMojo(model, "--format", "json");
        Assert.assertFalse(json.isEmpty());
        Assert.assertEquals(removeH2OVersion(readFixture(str)), removeH2OVersion(json));
    }

    /**
     * Renders {@code model} as PNG and checks the produced file name(s).
     *
     * @param expectedFileNames one name when a single image is expected at the output path; otherwise the
     *                          expected entries of the output directory in reverse-sorted order
     * @throws IllegalArgumentException if {@code expectedFileNames} is empty
     */
    private void assertMojoPngGenerated(Model model, String[] expectedFileNames) throws IOException {
        if (expectedFileNames.length == 0) {
            throw new IllegalArgumentException("expectedFileNames must not be empty");
        }
        File mojo = exportMojo(model);
        Path output = this.folder.newFile("exampleh2o.png").toPath();
        runPrintMojo(mojo, output.toFile(), "--format", "png");

        if (expectedFileNames.length == 1) {
            // Single image: only the output path's name is checked.
            Assert.assertTrue(output.endsWith(expectedFileNames[0]));
            return;
        }

        // Multiple images: the output path is listed as a directory; the stream must be closed.
        List<Path> generated;
        try (Stream<Path> entries = Files.list(output)) {
            generated = entries.sorted(Comparator.reverseOrder()).collect(Collectors.toList());
        }
        Assert.assertTrue("Expected at least " + expectedFileNames.length + " files but got " + generated,
                generated.size() >= expectedFileNames.length);
        for (int i = 0; i < expectedFileNames.length; i++) {
            Assert.assertTrue("Unexpected file " + generated.get(i) + ", wanted " + expectedFileNames[i],
                    generated.get(i).endsWith(expectedFileNames[i]));
        }
    }

    /** Replaces the {@code "h2o_version": "..."} entry with a constant placeholder. */
    private static String removeH2OVersion(String str) {
        return H2O_VERSION_PATTERN.matcher(str).replaceAll("h2o_version");
    }
}
