package hex.deeplearning;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.SupervisedModelBuilder;
import hex.deeplearning.DeepLearningModel;
import hex.schemas.DeepLearningV3;
import hex.schemas.ModelBuilderSchema;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import water.AutoBuffer;
import water.DKV;
import water.H2O;
import water.H2ONode;
import water.Job;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.init.Linpack;
import water.init.NetworkTest;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.PrettyPrint;

/* loaded from: input_file:hex/deeplearning/DeepLearning.class */
public class DeepLearning extends SupervisedModelBuilder<DeepLearningModel, DeepLearningModel.DeepLearningParameters, DeepLearningModel.DeepLearningModelOutput> {

    /* loaded from: input_file:hex/deeplearning/DeepLearning$DeepLearningDriver.class */
    public class DeepLearningDriver extends H2O.H2OCountedCompleter<DeepLearningDriver> {
        final transient String[] cp_modifiable = {"_seed", "_epochs", "_score_interval", "_train_samples_per_iteration", "_target_ratio_comm_to_comp", "_score_duty_cycle", "_score_training_samples", "_classification_stop", "_regression_stop", "_quiet_mode", "_max_confusion_matrix_size", "_max_hit_ratio_k", "_diagnostics", "_variable_importances", "_force_load_balance", "_replicate_training_data", "_shuffle_training_data", "_single_node_mode", "_fast_mode", "_l1", "_l2", "_max_w2", "_input_dropout_ratio", "_hidden_dropout_ratios", "_loss", "_overwrite_with_best_model", "_missing_values_handling", "_reproducible", "_export_weights_and_biases"};
        final transient String[] cp_not_modifiable = {"_drop_na20_cols", "_response_column", "_activation", "_use_all_factor_levels", "_adaptive_rate", "_autoencoder", "_rho", "_epsilon", "_sparse", "_sparsity_beta", "_col_major", "_rate", "_momentum_start", "_momentum_ramp", "_momentum_stable", "_nesterov_accelerated_gradient", "_ignore_const_cols", "_max_categorical_features"};
        transient HashSet<Frame> _delete_me = new HashSet<>();
        static final /* synthetic */ boolean $assertionsDisabled;

        public DeepLearningDriver() {
        }

        protected void compute2() {
            byte[] buf;
            try {
                try {
                    buf = new AutoBuffer().put(DeepLearning.this._parms).buf();
                    Scope.enter();
                    DeepLearning.this._parms.read_lock_frames(DeepLearning.this);
                    DeepLearning.this.init(true);
                } catch (Throwable th) {
                    if (DKV.getGet(DeepLearning.this._key)._state != Job.JobState.CANCELLED) {
                        DeepLearning.this.failed(th);
                        throw th;
                    }
                    Log.info(new Object[]{"Job cancelled by user."});
                    DeepLearning.this._parms.read_unlock_frames(DeepLearning.this);
                    Scope.exit(new Key[0]);
                }
                if (DeepLearning.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DeepLearning.this);
                }
                buildModel();
                byte[] buf2 = new AutoBuffer().put(DeepLearning.this._parms).buf();
                if (!$assertionsDisabled && !Arrays.equals(buf, buf2)) {
                    throw new AssertionError();
                }
                DeepLearning.this.done();
                DeepLearning.this._parms.read_unlock_frames(DeepLearning.this);
                Scope.exit(new Key[0]);
                tryComplete();
            } catch (Throwable th2) {
                DeepLearning.this._parms.read_unlock_frames(DeepLearning.this);
                Scope.exit(new Key[0]);
                throw th2;
            }
        }

        Key self() {
            return DeepLearning.this._key;
        }

        /* JADX WARN: Finally extract failed */
        public final void buildModel() {
            Scope.enter();
            DeepLearningModel deepLearningModel = null;
            if (DeepLearning.this._parms._checkpoint == null) {
                deepLearningModel = new DeepLearningModel(DeepLearning.this.dest(), DeepLearning.this._parms, new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this), DeepLearning.this._train, DeepLearning.this._valid);
                deepLearningModel.model_info().initializeMembers();
            } else {
                DeepLearningModel get = DKV.getGet(DeepLearning.this._parms._checkpoint);
                if (get == null) {
                    throw new IllegalArgumentException("Checkpoint not found.");
                }
                Log.info(new Object[]{"Resuming from checkpoint."});
                if (DeepLearning.this.isClassifier() != get._output.isClassifier()) {
                    throw new IllegalArgumentException("Response type must be the same as for the checkpointed model.");
                }
                if (DeepLearning.this.isSupervised() != get._output.isSupervised()) {
                    throw new IllegalArgumentException("Model type must be the same as for the checkpointed model.");
                }
                DeepLearningModel.DeepLearningParameters deepLearningParameters = (DeepLearningModel.DeepLearningParameters) get._parms;
                DeepLearningModel.DeepLearningParameters deepLearningParameters2 = (DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms;
                new Job.ProgressUpdate("Resuming from checkpoint").fork(DeepLearning.this._progressKey);
                if (deepLearningParameters2.getNumFolds() != 0) {
                    throw new UnsupportedOperationException("n_folds must be 0: Cross-validation is not supported during checkpoint restarts.");
                }
                if ((DeepLearning.this._parms._valid == null) != (get._parms._valid == null) || (DeepLearning.this._parms._valid != null && !DeepLearning.this._parms._valid.equals(get._parms._valid))) {
                    throw new IllegalArgumentException("Validation dataset must be the same as for the checkpointed model.");
                }
                if (!deepLearningParameters2._autoencoder && (deepLearningParameters2._response_column == null || !deepLearningParameters2._response_column.equals(deepLearningParameters._response_column))) {
                    throw new IllegalArgumentException("Response column (" + deepLearningParameters2._response_column + ") is not the same as for the checkpointed model: " + deepLearningParameters._response_column);
                }
                if (!Arrays.equals(deepLearningParameters2._hidden, deepLearningParameters._hidden)) {
                    throw new IllegalArgumentException("Hidden layers (" + Arrays.toString(deepLearningParameters2._hidden) + ") is not the same as for the checkpointed model: " + Arrays.toString(deepLearningParameters._hidden));
                }
                if (!Arrays.equals(deepLearningParameters2._ignored_columns, deepLearningParameters._ignored_columns)) {
                    throw new IllegalArgumentException("Predictor columns must be the same as for the checkpointed model. Check ignored columns.");
                }
                loop4: for (Field field : deepLearningParameters.getClass().getDeclaredFields()) {
                    if (ArrayUtils.contains(this.cp_not_modifiable, field.getName())) {
                        for (Field field2 : deepLearningParameters2.getClass().getDeclaredFields()) {
                            if (field.equals(field2)) {
                                try {
                                    if ((field2.get(deepLearningParameters2) == null || field.get(deepLearningParameters) == null || !field.get(deepLearningParameters).toString().equals(field2.get(deepLearningParameters2).toString())) && (field.get(deepLearningParameters) != null || field2.get(deepLearningParameters2) != null)) {
                                        throw new IllegalArgumentException("Cannot change parameter: '" + field.getName() + "': " + field.get(deepLearningParameters) + " -> " + field2.get(deepLearningParameters2));
                                        break loop4;
                                    }
                                } catch (IllegalAccessException e) {
                                    e.printStackTrace();
                                }
                            }
                        }
                    }
                }
                try {
                    DataInfo makeDataInfo = DeepLearning.makeDataInfo(DeepLearning.this._train, DeepLearning.this._valid, DeepLearning.this._parms);
                    DKV.put(makeDataInfo._key, makeDataInfo);
                    deepLearningModel = new DeepLearningModel(DeepLearning.this.dest(), DeepLearning.this._parms, get, false, makeDataInfo);
                    deepLearningModel.write_lock(self());
                    DeepLearningModel.DeepLearningParameters deepLearningParameters3 = deepLearningModel.model_info().get_params();
                    if (!$assertionsDisabled && deepLearningParameters3 == get.model_info().get_params()) {
                        throw new AssertionError();
                    }
                    if (!$assertionsDisabled && deepLearningParameters3 == deepLearningParameters2) {
                        throw new AssertionError();
                    }
                    if (!$assertionsDisabled && deepLearningParameters3 == deepLearningParameters) {
                        throw new AssertionError();
                    }
                    if (!Arrays.equals(deepLearningModel._output._names, get._output._names)) {
                        throw new IllegalArgumentException("Predictor columns of the training data must be the same as for the checkpointed model. Check ignored columns.");
                    }
                    if (!Arrays.deepEquals(deepLearningModel._output._domains, get._output._domains)) {
                        throw new IllegalArgumentException("Categorical factor levels of the training data must be the same as for the checkpointed model.");
                    }
                    if (makeDataInfo.fullN() != get.model_info().data_info().fullN()) {
                        throw new IllegalArgumentException("Total number of predictors is different than for the checkpointed model.");
                    }
                    for (Field field3 : deepLearningParameters3.getClass().getDeclaredFields()) {
                        if (ArrayUtils.contains(this.cp_modifiable, field3.getName())) {
                            for (Field field4 : deepLearningParameters2.getClass().getDeclaredFields()) {
                                if (field3.equals(field4)) {
                                    try {
                                        if (field4.get(deepLearningParameters2) == null || field3.get(deepLearningParameters3) == null || !field3.get(deepLearningParameters3).toString().equals(field4.get(deepLearningParameters2).toString())) {
                                            if (field3.get(deepLearningParameters3) != null || field4.get(deepLearningParameters2) != null) {
                                                Log.info(new Object[]{"Applying user-requested modification of '" + field3.getName() + "': " + field3.get(deepLearningParameters3) + " -> " + field4.get(deepLearningParameters2)});
                                                field3.set(deepLearningParameters3, field4.get(deepLearningParameters2));
                                            }
                                        }
                                    } catch (IllegalAccessException e2) {
                                        e2.printStackTrace();
                                    }
                                }
                            }
                        }
                    }
                    DeepLearningModel.modifyParms(deepLearningParameters3, deepLearningParameters3, DeepLearning.this.isClassifier());
                    deepLearningParameters3._epochs += get.epoch_counter;
                    Log.info(new Object[]{"Adding " + String.format("%.3f", Double.valueOf(get.epoch_counter)) + " epochs from the checkpointed model."});
                    if (deepLearningParameters3.getNumFolds() != 0) {
                        Log.info(new Object[]{"Disabling cross-validation: Not supported when resuming training from a checkpoint."});
                        H2O.unimpl("writing to n_folds field needs to be uncommented");
                    }
                    deepLearningModel.update(self());
                    if (deepLearningModel != null) {
                        deepLearningModel.unlock(self());
                    }
                } catch (Throwable th) {
                    if (deepLearningModel != null) {
                        deepLearningModel.unlock(self());
                    }
                    throw th;
                }
            }
            trainModel(deepLearningModel);
            ArrayList arrayList = new ArrayList();
            arrayList.add(DeepLearning.this.dest());
            if (deepLearningModel._output._model_metrics.length != 0) {
                arrayList.add(deepLearningModel._output._model_metrics[deepLearningModel._output._model_metrics.length - 1]);
            }
            if (deepLearningModel._output.weights != null && deepLearningModel._output.biases != null) {
                for (Key key : Arrays.asList(deepLearningModel._output.weights)) {
                    arrayList.add(key);
                    for (Vec vec : DKV.getGet(key).vecs()) {
                        arrayList.add(vec._key);
                    }
                }
                for (Key key2 : Arrays.asList(deepLearningModel._output.biases)) {
                    arrayList.add(key2);
                    for (Vec vec2 : DKV.getGet(key2).vecs()) {
                        arrayList.add(vec2._key);
                    }
                }
            }
            Scope.exit((Key[]) arrayList.toArray(new Key[0]));
        }

        public final DeepLearningModel trainModel(DeepLearningModel deepLearningModel) {
            DeepLearningModel get;
            Frame frame = null;
            try {
                if (deepLearningModel == null) {
                    try {
                        deepLearningModel = (DeepLearningModel) DKV.get(DeepLearning.this.dest()).get();
                    } catch (Throwable th) {
                        DKV.get(DeepLearning.this.dest()).get();
                        Log.info(new Object[]{"Deep Learning model building was cancelled."});
                        throw new RuntimeException(th);
                    }
                }
                Object[] objArr = new Object[1];
                objArr[0] = "Model category: " + (DeepLearning.this._parms._autoencoder ? "Auto-Encoder" : DeepLearning.this.isClassifier() ? "Classification" : "Regression");
                Log.info(objArr);
                Log.info(new Object[]{"Number of model parameters (weights/biases): " + String.format("%,d", Long.valueOf(deepLearningModel.model_info().size()))});
                deepLearningModel.write_lock(self());
                new Job.ProgressUpdate("Setting up training data...").fork(DeepLearning.this._progressKey);
                DeepLearningModel.DeepLearningParameters deepLearningParameters = deepLearningModel.model_info().get_params();
                Frame frame2 = new Frame(deepLearningParameters.train()._key, DeepLearning.this._train.names(), DeepLearning.this._train.vecs());
                Frame frame3 = DeepLearning.this._valid != null ? new Frame(deepLearningParameters.valid()._key, DeepLearning.this._valid.names(), DeepLearning.this._valid.vecs()) : null;
                Frame frame4 = frame2;
                if (deepLearningParameters._force_load_balance) {
                    new Job.ProgressUpdate("Load balancing training data...").fork(DeepLearning.this._progressKey);
                    frame4 = reBalance(frame4, deepLearningParameters._replicate_training_data, deepLearningParameters._train.toString() + "." + deepLearningModel._key.toString() + ".train");
                }
                if (deepLearningModel._output.isClassifier() && deepLearningParameters._balance_classes) {
                    new Job.ProgressUpdate("Balancing class distribution of training data...").fork(DeepLearning.this._progressKey);
                    float[] fArr = new float[frame4.lastVec().domain().length];
                    if (deepLearningParameters._class_sampling_factors != null) {
                        if (deepLearningParameters._class_sampling_factors.length != frame4.lastVec().domain().length) {
                            throw new IllegalArgumentException("class_sampling_factors must have " + frame4.lastVec().domain().length + " elements");
                        }
                        fArr = (float[]) deepLearningParameters._class_sampling_factors.clone();
                    }
                    frame4 = MRUtils.sampleFrameStratified(frame4, frame4.lastVec(), fArr, deepLearningParameters._max_after_balance_size * ((float) frame4.numRows()), deepLearningParameters._seed, true, false);
                    deepLearningModel._output._modelClassDist = new MRUtils.ClassDist(frame4.lastVec()).doAll(new Vec[]{frame4.lastVec()}).rel_dist();
                }
                deepLearningModel._output.autoencoder = DeepLearning.this._parms._autoencoder;
                deepLearningModel.training_rows = frame4.numRows();
                Frame sampleFrame = MRUtils.sampleFrame(frame4, deepLearningParameters._score_training_samples, deepLearningParameters._seed);
                if (!DeepLearning.this._parms._quiet_mode) {
                    Log.info(new Object[]{"Number of chunks of the training data: " + frame4.anyVec().nChunks()});
                }
                if (frame3 != null) {
                    deepLearningModel.validation_rows = frame3.numRows();
                    if (deepLearningModel._output.isClassifier() && deepLearningParameters._balance_classes && deepLearningParameters._score_validation_sampling == DeepLearningModel.DeepLearningParameters.ClassSamplingMethod.Stratified) {
                        new Job.ProgressUpdate("Sampling validation data (stratified)...").fork(DeepLearning.this._progressKey);
                        frame = MRUtils.sampleFrameStratified(frame3, frame3.lastVec(), (float[]) null, deepLearningParameters._score_validation_samples > 0 ? deepLearningParameters._score_validation_samples : frame3.numRows(), deepLearningParameters._seed + 1, false, false);
                    } else {
                        new Job.ProgressUpdate("Sampling validation data...").fork(DeepLearning.this._progressKey);
                        frame = MRUtils.sampleFrame(frame3, deepLearningParameters._score_validation_samples, deepLearningParameters._seed + 1);
                    }
                    if (deepLearningParameters._force_load_balance) {
                        new Job.ProgressUpdate("Balancing class distribution of validation data...").fork(DeepLearning.this._progressKey);
                        frame = reBalance(frame, false, deepLearningParameters._valid.toString() + "." + deepLearningModel._key.toString() + ".valid");
                    }
                    if (!DeepLearning.this._parms._quiet_mode) {
                        Log.info(new Object[]{"Number of chunks of the validation data: " + frame.anyVec().nChunks()});
                    }
                }
                deepLearningModel.actual_train_samples_per_iteration = computeTrainSamplesPerIteration(deepLearningParameters, frame4.numRows(), deepLearningModel);
                if (deepLearningParameters._replicate_training_data) {
                    if (deepLearningModel.actual_train_samples_per_iteration == frame4.numRows() * (deepLearningParameters._single_node_mode ? 1 : H2O.CLOUD.size()) && !deepLearningParameters._shuffle_training_data && H2O.CLOUD.size() > 1 && !deepLearningParameters._reproducible) {
                        Log.info(new Object[]{"Enabling training data shuffling, because all nodes train on the full dataset (replicated training data)."});
                        deepLearningParameters._shuffle_training_data = true;
                    }
                }
                if (!deepLearningParameters._quiet_mode && deepLearningParameters._diagnostics) {
                    Log.info(new Object[]{"Initial model:\n" + deepLearningModel.model_info()});
                }
                if (DeepLearning.this._parms._autoencoder) {
                    new Job.ProgressUpdate("Scoring null model of autoencoder...").fork(DeepLearning.this._progressKey);
                    deepLearningModel.doScoring(sampleFrame, frame, self(), null);
                }
                deepLearningModel.update(self());
                deepLearningModel._timeLastScoreEnter = System.currentTimeMillis();
                Log.info(new Object[]{"Starting to train the Deep Learning model."});
                do {
                    new Job.ProgressUpdate("Training" + (deepLearningModel.run_time != 0 ? " at " + ((deepLearningModel.model_info().get_processed_total() * 1000) / deepLearningModel.run_time) + " samples/s..." : "...") + (deepLearningModel.run_time == 0 ? "" : " Estimated time left: " + PrettyPrint.msecs((long) ((deepLearningModel.run_time * (1.0d - DeepLearning.this.progress())) / DeepLearning.this.progress()), true))).fork(DeepLearning.this._progressKey);
                    deepLearningModel.set_model_info(deepLearningParameters._epochs == 0.0d ? deepLearningModel.model_info() : (H2O.CLOUD.size() <= 1 || !deepLearningParameters._replicate_training_data) ? ((DeepLearningTask) new DeepLearningTask(self(), deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel)).doAll(frame4)).model_info() : deepLearningParameters._single_node_mode ? ((DeepLearningTask2) new DeepLearningTask2(self(), frame4, deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel)).doAll(new Key[]{Key.make()})).model_info() : ((DeepLearningTask2) new DeepLearningTask2(self(), frame4, deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel)).doAllNodes()).model_info());
                    DeepLearning.this.update(deepLearningModel.actual_train_samples_per_iteration);
                } while (deepLearningModel.doScoring(sampleFrame, frame, self(), DeepLearning.this._progressKey));
                if (!DeepLearning.this.isCancelledOrCrashed() && DeepLearning.this._parms._overwrite_with_best_model && deepLearningModel.actual_best_model_key != null && DeepLearning.this._parms.getNumFolds() == 0 && (get = DKV.getGet(deepLearningModel.actual_best_model_key)) != null && get.error() < deepLearningModel.error() && Arrays.equals(get.model_info().units, deepLearningModel.model_info().units)) {
                    Log.info(new Object[]{"Setting the model to be the best model so far (based on scoring history)."});
                    DeepLearningModel.DeepLearningModelInfo deep_clone = get.model_info().deep_clone();
                    deep_clone.set_processed_global(deepLearningModel.model_info().get_processed_global());
                    deep_clone.set_processed_local(deepLearningModel.model_info().get_processed_local());
                    deepLearningModel.set_model_info(deep_clone);
                    deepLearningModel.update(self());
                    deepLearningModel.doScoring(sampleFrame, frame, self(), DeepLearning.this._progressKey);
                    if (!$assertionsDisabled && get.error() != deepLearningModel.error()) {
                        throw new AssertionError();
                    }
                }
                Log.info(new Object[]{"=============================================================================================================================================================================="});
                Log.info(new Object[]{"Finished training the Deep Learning model."});
                Log.info(new Object[]{deepLearningModel});
                Log.info(new Object[]{"=============================================================================================================================================================================="});
                if (deepLearningModel != null) {
                    deepLearningModel.unlock(self());
                    if (deepLearningModel.actual_best_model_key != null) {
                        if (!$assertionsDisabled && deepLearningModel.actual_best_model_key == deepLearningModel._key) {
                            throw new AssertionError();
                        }
                        DKV.remove(deepLearningModel.actual_best_model_key);
                    }
                }
                Iterator<Frame> it = this._delete_me.iterator();
                while (it.hasNext()) {
                    it.next().delete();
                }
                return deepLearningModel;
            } catch (Throwable th2) {
                if (deepLearningModel != null) {
                    deepLearningModel.unlock(self());
                    if (deepLearningModel.actual_best_model_key != null) {
                        if (!$assertionsDisabled && deepLearningModel.actual_best_model_key == deepLearningModel._key) {
                            throw new AssertionError();
                        }
                        DKV.remove(deepLearningModel.actual_best_model_key);
                    }
                }
                Iterator<Frame> it2 = this._delete_me.iterator();
                while (it2.hasNext()) {
                    it2.next().delete();
                }
                throw th2;
            }
        }

        private Frame reBalance(Frame frame, boolean z, String str) {
            int min = (int) Math.min(4 * H2O.NUMCPUS * (z ? 1 : H2O.CLOUD.size()), frame.numRows());
            if (frame.anyVec().nChunks() > min && !DeepLearning.this._parms._reproducible) {
                Log.info(new Object[]{"Dataset already contains " + frame.anyVec().nChunks() + " chunks. No need to rebalance."});
                return frame;
            }
            if (DeepLearning.this._parms._reproducible) {
                Log.warn(new Object[]{"Reproducibility enforced - using only 1 thread - can be slow."});
                min = 1;
            }
            if (!DeepLearning.this._parms._quiet_mode) {
                Log.info(new Object[]{"ReBalancing dataset into (at least) " + min + " chunks."});
            }
            Key make = Key.make(str + ".chunks" + min);
            RebalanceDataSet rebalanceDataSet = new RebalanceDataSet(frame, make, min);
            H2O.submitTask(rebalanceDataSet);
            rebalanceDataSet.join();
            Frame frame2 = DKV.get(make).get();
            this._delete_me.add(frame2);
            return frame2;
        }

        private long computeTrainSamplesPerIteration(DeepLearningModel.DeepLearningParameters deepLearningParameters, long j, DeepLearningModel deepLearningModel) {
            long j2;
            long j3 = deepLearningParameters._train_samples_per_iteration;
            if (!$assertionsDisabled && j3 != 0 && j3 != -1 && j3 != -2 && j3 < 1) {
                throw new AssertionError();
            }
            if (j3 == 0 || (!deepLearningParameters._replicate_training_data && j3 == -1)) {
                j2 = j;
                if (!deepLearningParameters._quiet_mode) {
                    Log.info(new Object[]{"Setting train_samples_per_iteration (" + deepLearningParameters._train_samples_per_iteration + ") to one epoch: #rows (" + j2 + ")."});
                }
            } else if (j3 == -1) {
                j2 = (deepLearningParameters._single_node_mode ? 1 : H2O.CLOUD.size()) * j;
                if (!deepLearningParameters._quiet_mode) {
                    Log.info(new Object[]{"Setting train_samples_per_iteration (" + deepLearningParameters._train_samples_per_iteration + ") to #nodes x #rows (" + j2 + ")."});
                }
            } else if (j3 == -2) {
                double d = 0.0d;
                for (H2ONode h2ONode : H2O.CLOUD._memary) {
                    d += h2ONode._heartbeat._gflops;
                }
                if (deepLearningParameters._single_node_mode) {
                    d /= H2O.CLOUD.size();
                }
                if (d == 0.0d) {
                    d = Linpack.run(H2O.SELF._heartbeat._cpus_allowed) * (deepLearningParameters._single_node_mode ? 1 : H2O.CLOUD.size());
                }
                long size = deepLearningModel.model_info().size();
                int[] iArr = new int[1];
                iArr[0] = ((long) ((int) (size * 4))) == size * 4 ? (int) (size * 4) : Neurons.missing_int_value;
                double[] dArr = new double[iArr.length];
                new NetworkTest.NetworkTester(iArr, (double[][]) null, dArr, ((double) size) > 1000000.0d ? 1 : 5, false, true).compute2();
                int floor = (deepLearningParameters._single_node_mode || H2O.CLOUD.size() == 1) ? 1 : 2 * ((int) Math.floor(Math.log(H2O.CLOUD.size()) / Math.log(2.0d)));
                double d2 = 30.0d;
                if (deepLearningParameters._activation == DeepLearningModel.DeepLearningParameters.Activation.Maxout || deepLearningParameters._activation == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout) {
                    d2 = 30.0d * 8.0d;
                } else if (deepLearningParameters._activation == DeepLearningModel.DeepLearningParameters.Activation.Tanh || deepLearningParameters._activation == DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout) {
                    d2 = 30.0d * 5.0d;
                }
                double d3 = (deepLearningParameters._single_node_mode || H2O.CLOUD.size() == 1) ? 0.001d : 0.05d;
                deepLearningModel.time_for_communication_us = (H2O.CLOUD.size() == 1 ? 10000.0d : 0.0d) + (floor * dArr[0]);
                double d4 = (((d2 * size) / (d * 1.0E9d)) / H2O.SELF._heartbeat._cpus_allowed) * 1000000.0d;
                long min = Math.min((long) (((deepLearningModel.time_for_communication_us / d3) - deepLearningModel.time_for_communication_us) / d4), (deepLearningParameters._single_node_mode ? 1 : H2O.CLOUD.size()) * j * 10);
                if (min > j && Math.abs(min % j) / j < 0.2d) {
                    min -= min % j;
                }
                j2 = Math.max(1L, Math.min(min, (long) ((deepLearningParameters._epochs * j) / 10.0d)));
                if (!deepLearningParameters._quiet_mode) {
                    Log.info(new Object[]{"Auto-tuning parameter 'train_samples_per_iteration':"});
                    Log.info(new Object[]{"Estimated compute power : " + ((int) d) + " GFlops"});
                    Log.info(new Object[]{"Estimated time for comm : " + PrettyPrint.usecs((long) deepLearningModel.time_for_communication_us)});
                    Object[] objArr = new Object[1];
                    objArr[0] = "Estimated time per row  : " + (((long) d4) > 0 ? PrettyPrint.usecs((long) d4) : d4 + " usecs");
                    Log.info(objArr);
                    Log.info(new Object[]{"Estimated training speed: " + ((int) (1000000.0d / d4)) + " rows/sec"});
                    Log.info(new Object[]{"Setting train_samples_per_iteration (" + deepLearningParameters._train_samples_per_iteration + ") to auto-tuned value: " + j2});
                }
            } else {
                j2 = Math.min(j3, (long) (deepLearningParameters._epochs * j));
            }
            if ($assertionsDisabled || !(j2 == 0 || j2 == -1 || j2 == -2 || j2 < 1)) {
                return j2;
            }
            throw new AssertionError();
        }

        /**
         * Computes the fraction of the training rows consumed per iteration.
         *
         * @param numRows             total number of rows (used as the float divisor)
         * @param samplesPerIteration number of samples processed per iteration
         * @param replicated          when true, the fraction is further divided by the cloud size
         * @return samplesPerIteration / numRows, divided by {@code H2O.CLOUD.size()} if replicated
         * @throws AssertionError if assertions are enabled and the result is not strictly positive
         */
        private float computeRowUsageFraction(long numRows, long samplesPerIteration, boolean replicated) {
            float fraction = (float) samplesPerIteration / (float) numRows;
            if (replicated) {
                // each node holds the full data set, so the per-node share shrinks with cloud size
                fraction /= H2O.CLOUD.size();
            }
            // decompiled assertion: `!(x > 0)` (not `x <= 0`) so NaN also trips it, as in the original
            if (!$assertionsDisabled && !(fraction > 0.0f)) {
                throw new AssertionError();
            }
            return fraction;
        }

        /**
         * Returns the per-iteration row usage fraction for the given training frame and model.
         *
         * @param train the training frame whose row count is the denominator
         * @param parms parameters; {@code _replicate_training_data} selects the per-node division
         * @param model model carrying {@code actual_train_samples_per_iteration}
         * @return the value computed by {@link #computeRowUsageFraction(long, long, boolean)}
         */
        private float rowFraction(Frame train, DeepLearningModel.DeepLearningParameters parms, DeepLearningModel model) {
            final long totalRows = train.numRows();
            final long perIteration = model.actual_train_samples_per_iteration;
            final boolean replicated = parms._replicate_training_data;
            return computeRowUsageFraction(totalRows, perIteration, replicated);
        }

        // Initializes the `$assertionsDisabled` flag read by the assertion guards in this class:
        // it is true when assertions are disabled for DeepLearning. This appears to be the
        // decompiled form of javac's synthetic assertion switch (the source likely used `assert`).
        static {
            $assertionsDisabled = !DeepLearning.class.desiredAssertionStatus();
        }
    }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return a fresh array containing Regression, Binomial and Multinomial
     */
    public Model.ModelCategory[] can_build() {
        Model.ModelCategory[] supported = {
            Model.ModelCategory.Regression,
            Model.ModelCategory.Binomial,
            Model.ModelCategory.Multinomial
        };
        return supported;
    }

    /**
     * Reports the visibility level of this model builder.
     *
     * @return always {@code BuilderVisibility.Stable}
     */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /**
     * Tells whether this builder trains against a response column.
     *
     * @return false when the parameters request an autoencoder, true otherwise
     */
    public boolean isSupervised() {
        if (this._parms._autoencoder) {
            // autoencoder mode is treated as unsupervised
            return false;
        }
        return true;
    }

    /**
     * Creates a DeepLearning builder for the given parameters and runs the initial
     * {@code init(false)} pass over them.
     *
     * @param parms the deep learning parameters to build with
     */
    public DeepLearning(DeepLearningModel.DeepLearningParameters parms) {
        super("DeepLearning", parms);
        // false: skips the memory-footprint check that init performs only when its flag is true
        init(false);
    }

    /**
     * Returns a new schema instance describing this builder.
     *
     * @return a freshly constructed {@code DeepLearningV3}
     */
    public ModelBuilderSchema schema() {
        final ModelBuilderSchema builderSchema = new DeepLearningV3();
        return builderSchema;
    }

    /**
     * Starts the DeepLearning training job with a {@link DeepLearningDriver}.
     *
     * <p>The job's work estimate is {@code epochs * training rows}. If the training frame has not
     * been resolved ({@code _train == null}), a placeholder estimate of 1 is used instead of
     * throwing a NullPointerException here. NOTE(review): assumes the driver then reports the
     * missing frame itself — confirm.
     *
     * @return the started job
     */
    public Job<DeepLearningModel> trainModel() {
        long work = 1L;
        if (this._train != null) {
            work = (long) (this._parms._epochs * this._train.numRows());
        }
        return start(new DeepLearningDriver(), work);
    }

    /**
     * Initializes the builder: runs the superclass init, then validates the parameters.
     * When {@code expensive} is true and no errors have been recorded, also runs
     * {@link #checkMemoryFootPrint()}.
     *
     * @param expensive whether to run the additional (memory-footprint) check
     */
    public void init(boolean expensive) {
        super.init(expensive);
        this._parms.validate(this, expensive);
        if (!expensive || error_count() != 0) {
            return;
        }
        checkMemoryFootPrint();
    }

    /**
     * Builds the {@link DataInfo} used to feed frames into the network.
     * (Decompiler note: this method was originally package-private.)
     *
     * <ul>
     *   <li>responses: 0 for an autoencoder, 1 otherwise;</li>
     *   <li>all factor levels are used for an autoencoder or when {@code _use_all_factor_levels} is set;</li>
     *   <li>predictors are NORMALIZEd for an autoencoder, STANDARDIZEd otherwise;</li>
     *   <li>the last column gets no transform if it is categorical, STANDARDIZE otherwise;</li>
     *   <li>the skip flag is set when missing values handling is {@code Skip}.</li>
     * </ul>
     *
     * @param frame  primary frame; its last vec decides the response transform
     * @param frame2 second frame handed to DataInfo unchanged
     * @param deepLearningParameters parameters controlling the flags above
     * @return a new DataInfo under a freshly made key
     */
    public static DataInfo makeDataInfo(Frame frame, Frame frame2, DeepLearningModel.DeepLearningParameters deepLearningParameters) {
        final boolean autoencoder = deepLearningParameters._autoencoder;
        final int responseCount = autoencoder ? 0 : 1;
        final boolean allFactorLevels = autoencoder || deepLearningParameters._use_all_factor_levels;
        final DataInfo.TransformType predictorTransform =
                autoencoder ? DataInfo.TransformType.NORMALIZE : DataInfo.TransformType.STANDARDIZE;
        final DataInfo.TransformType responseTransform =
                frame.lastVec().isEnum() ? DataInfo.TransformType.NONE : DataInfo.TransformType.STANDARDIZE;
        final boolean skipMissing = deepLearningParameters._missing_values_handling
                == DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip;
        // NOTE(review): meaning of the trailing `true` depends on DataInfo's constructor — confirm
        return new DataInfo(Key.make(), frame, frame2, responseCount, allFactorLevels,
                predictorTransform, responseTransform, skipMissing, true);
    }

    /**
     * Rejects networks whose estimated number of parameters (weights plus biases) exceeds 1e8.
     * On rejection it records an error on {@code "_hidden"} and cancels the job with the same
     * message. The check is skipped entirely when a checkpoint is given.
     *
     * <p>The input width starts from the training frame's {@code degreesOfFreedom()}, minus the
     * response cardinality for supervised models. One extra input is then added for every
     * predictor column with a non-null domain (presumably a missing-value level — confirm).
     * The output width is the input width for an autoencoder, otherwise
     * {@code |cardinality|} of the last column (Math.abs presumably maps a negative,
     * non-categorical cardinality to one output — confirm).
     */
    protected void checkMemoryFootPrint() {
        if (this._parms._checkpoint != null) {
            return;
        }
        final boolean autoencoder = this._parms._autoencoder;
        long numInputs = this._train.degreesOfFreedom() - (autoencoder ? 0 : this._train.lastVec().cardinality());
        String[][] domains = this._train.domains();
        // for supervised models the last column is the response and is excluded here
        int numPredictorCols = this._train.numCols() - (autoencoder ? 0 : 1);
        for (int col = 0; col < numPredictorCols; col++) {
            if (domains[col] != null) {
                numInputs++;
            }
        }
        long numOutputs = autoencoder ? numInputs : Math.abs(this._train.lastVec().cardinality());
        long numParameters = countNetworkParameters(numInputs, this._parms._hidden, numOutputs);
        if (numParameters > 1.0E8d) {
            String str = "Model is too large: " + numParameters + " parameters. Try reducing the number of neurons in the hidden layers (or reduce the number of categorical factors).";
            error("_hidden", str);
            cancel(str);
        }
    }

    /**
     * Counts weights and biases of a fully connected network: input->hidden[0],
     * hidden[i-1]->hidden[i], hidden[last]->output weights, plus one bias per hidden and output unit.
     * All products are computed in {@code long} so wide layers cannot overflow {@code int}.
     * Throws ArrayIndexOutOfBoundsException if {@code hidden} is empty, as the original did.
     */
    private static long countNetworkParameters(long numInputs, int[] hidden, long numOutputs) {
        long count = numInputs * hidden[0];
        for (int layer = 1; layer < hidden.length; layer++) {
            // widen before multiplying: int * int would overflow for large layers
            count += (long) hidden[layer - 1] * hidden[layer];
        }
        count += hidden[hidden.length - 1] * numOutputs;
        for (int units : hidden) {
            count += units;
        }
        return count + numOutputs;
    }
}
