package hex.deeplearning;

import hex.deeplearning.DeepLearningModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.parser.ParseDataset;
import water.util.Log;

/**
 * Trains a small deep-learning autoencoder on the airlines training data, whose columns mix
 * categorical and numeric types, and cross-checks the model's reconstruction errors.
 *
 * <p>The test asserts that:
 * <ul>
 *   <li>the reported {@code mse()} equals the mean per-row reconstruction error returned by
 *       {@code scoreAutoEncoder};</li>
 *   <li>{@code testJavaScoring} agrees with in-cluster scoring of the full reconstruction;</li>
 *   <li>each hidden layer's deep features have the configured width and one row per input row;</li>
 *   <li>the MOJO's mean reconstruction error matches the reported {@code mse()}.</li>
 * </ul>
 */
public class DeepLearningAutoEncoderCategoricalTest extends TestUtil {
    static final String PATH = "smalldata/airlines/AirlinesTrain.csv.zip";

    /** Hidden layer sizes of the autoencoder; deep-feature widths are asserted against these. */
    private static final int[] HIDDEN = {10, 5, 3};

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains the autoencoder and verifies its reconstruction errors and deep features.
     *
     * <p>The model and every frame created here are deleted in {@code finally}, so a failing
     * assertion does not leave them behind in the cloud shared with later tests.
     */
    @Test
    public void run() {
        Frame train = null;
        DeepLearningModel model = null;
        Frame l2 = null;
        Frame recon = null;
        Frame[] deepFeatures = new Frame[HIDDEN.length];
        StringBuilder sb = new StringBuilder();
        try {
            train = ParseDataset.parse(Key.make("train.hex"), new Key[]{TestUtil.makeNfsFileVec(PATH)._key});
            model = new DeepLearning(buildParameters(train)).trainModel().get();

            sb.append("Verifying results.\n");
            sb.append("Reported mean reconstruction error: " + model.mse() + "\n");

            // The per-feature error frame is only logged, so release it immediately.
            Frame perFeature = model.scoreAutoEncoder(train, Key.make(), true);
            try {
                sb.append("Reconstruction error per feature: " + perFeature.toString() + "\n");
            } finally {
                perFeature.remove();
            }

            l2 = model.scoreAutoEncoder(train, Key.make(), false);
            Vec l2vec = l2.anyVec();
            sb.append("Actual   mean reconstruction error: " + l2vec.mean() + "\n");
            appendOutliers(sb, model, l2vec, train.numRows());

            // The reported MSE should equal the mean of the per-row reconstruction errors.
            Assert.assertEquals(l2vec.mean(), model.mse(), 1.0E-8d * model.mse());

            Log.info("Creating full reconstruction.");
            recon = model.score(train);
            Assert.assertTrue(model.testJavaScoring(train, recon, 1.0E-5d));

            // Store each frame before asserting on it so the finally block can always free it.
            for (int layer = 0; layer < HIDDEN.length; layer++) {
                deepFeatures[layer] = model.scoreDeepFeatures(train, layer);
                Assert.assertEquals(HIDDEN[layer], deepFeatures[layer].numCols());
                Assert.assertEquals(train.numRows(), deepFeatures[layer].numRows());
            }

            double mojoMse;
            try {
                mojoMse = mojoMeanReconstructionError(model, train);
            } catch (IOException | PredictException e) {
                // Keep the cause so the failure report shows what actually went wrong.
                throw new AssertionError("MOJO reconstruction scoring failed", e);
            }
            sb.append("Mojo mean reconstruction error (train): ").append(mojoMse).append("\n");
            sb.append("Mean reconstruction error should be the same from model compare to mojo model reconstruction error: ");
            sb.append(model.mse()).append(" == ").append(mojoMse).append("\n");
            Assert.assertEquals(model.mse(), mojoMse, 1.0E-7d);
        } finally {
            Log.info(sb.toString());
            if (recon != null) {
                recon.delete();
            }
            if (l2 != null) {
                l2.delete();
            }
            for (Frame f : deepFeatures) {
                if (f != null) {
                    f.delete();
                }
            }
            if (model != null) {
                model.delete();
            }
            if (train != null) {
                train.delete();
            }
        }
    }

    /**
     * Builds the autoencoder training parameters used by {@link #run()}.
     *
     * @param train the parsed training frame; its last column is set as the response column
     * @return a fresh parameter object configured for a reproducible, short Tanh/Huber run
     */
    private static DeepLearningModel.DeepLearningParameters buildParameters(Frame train) {
        DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
        p._train = train._key;
        p._autoencoder = true;
        p._response_column = train.names()[train.names().length - 1];
        p._seed = 912559L;
        // Copy so the assertions in run() compare against the sizes this test requested.
        p._hidden = HIDDEN.clone();
        p._adaptive_rate = true;
        p._l1 = 1.0E-4d;
        p._activation = DeepLearningModel.DeepLearningParameters.Activation.Tanh;
        p._max_w2 = 10.0f;
        p._train_samples_per_iteration = -1L;
        p._loss = DeepLearningModel.DeepLearningParameters.Loss.Huber;
        p._epochs = 0.2d;
        p._force_load_balance = true;
        p._score_training_samples = 0L;
        p._score_validation_samples = 0L;
        p._reproducible = true;
        return p;
    }

    /**
     * Appends a header to {@code sb}, then one line for each row whose reconstruction error
     * exceeds {@code model.calcOutlierThreshold(l2vec, 1 - 5/numRows)}.
     *
     * @param sb      log buffer to append to
     * @param model   trained autoencoder used to compute the threshold
     * @param l2vec   per-row reconstruction errors
     * @param numRows number of training rows, used to derive the quantile
     */
    private static void appendOutliers(StringBuilder sb, DeepLearningModel model, Vec l2vec, long numRows) {
        double quantile = 1.0d - (5.0d / numRows);
        sb.append("The following training points are reconstructed with an error above the " + (quantile * 100.0d) + "-th percentile - potential \"outliers\" in testing data.\n");
        double threshold = model.calcOutlierThreshold(l2vec, quantile);
        for (long row = 0; row < l2vec.length(); row++) {
            double err = l2vec.at(row);
            if (err > threshold) {
                sb.append(String.format("row %d : l2vec error = %5f\n", row, err));
            }
        }
    }

    /**
     * Scores every row of {@code frame} through the model's MOJO and returns the mean per-row
     * reconstruction MSE.
     *
     * @param model trained autoencoder to export as a MOJO
     * @param frame rows to score; categorical cells are passed as level strings, all others as doubles
     * @return sum of {@code predictAutoEncoder(row).mse} divided by the frame's row count
     * @throws IOException      if building the MOJO fails
     * @throws PredictException if the MOJO fails to score a row
     */
    private static double mojoMeanReconstructionError(DeepLearningModel model, Frame frame)
            throws IOException, PredictException {
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(model.toMojo());
        String[] names = frame.names();
        Vec[] vecs = frame.vecs();
        BufferedString tmp = new BufferedString();
        double sum = 0.0d;
        for (long row = 0; row < frame.numRows(); row++) {
            RowData rowData = new RowData();
            for (int c = 0; c < vecs.length; c++) {
                Vec v = vecs[c];
                if (v.isCategorical()) {
                    // NOTE(review): atStr is expected to return null for an NA cell. Leaving the
                    // column out lets the wrapper treat it as missing rather than hitting an NPE.
                    if (!v.isNA(row)) {
                        rowData.put(names[c], v.atStr(tmp, row).toString());
                    }
                } else {
                    rowData.put(names[c], Double.valueOf(v.at(row)));
                }
            }
            sum += wrapper.predictAutoEncoder(rowData).mse;
        }
        return sum / frame.numRows();
    }
}
