package hex.genmodel.algos.deeplearning;

import hex.genmodel.algos.deeplearning.ActivationUtils;
import hex.genmodel.algos.deeplearning.DeeplearningMojoModel;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.impl.SimpleLog;

/**
 * Forward propagation through a single fully connected layer of a deep learning MOJO model.
 *
 * <p>The layer combines {@link #_inputs} with the weights and biases held in {@link #_weightsAndBias}. It then applies
 * the activation function named by {@link #_activation}. For the Maxout activations, every output node owns
 * {@link #_maxK} channels; otherwise {@link #_maxK} is 1.
 */
public class NeuralNetwork {
    /** Name of the activation function; expected to be one of {@link #_validActivation}. */
    public String _activation;
    /** Dropout ratio handed to the activation function; validated to lie in [0, 1) when assertions are enabled. */
    double _drop_out_ratio;
    /** Weights ({@code _wValues}) and biases ({@code _bValues}) of this layer. */
    public DeeplearningMojoModel.StoreWeightsBias _weightsAndBias;
    /** Layer input vector. */
    public double[] _inputs;
    /** Allocated by the constructor with length {@link #_outSize}; {@link #fprop1Layer()} returns a fresh array instead. */
    public double[] _outputs;
    /** Number of nodes in this layer. */
    public int _outSize;
    /** Number of inputs to this layer ({@code _inputs.length}). */
    public int _inSize;
    /** Channels per output node: {@code _bValues.length / outSize} for Maxout activations, otherwise 1. */
    public int _maxK = 1;
    List<String> _validActivation = Arrays.asList("Linear", "Softmax", "ExpRectifierWithDropout", "ExpRectifier", "Rectifier", "RectifierWithDropout", "MaxoutWithDropout", "Maxout", "TanhWithDropout", "Tanh");

    /**
     * Creates a layer evaluator.
     *
     * @param activation     activation function name (see {@link #_validActivation})
     * @param dropOutRatio   dropout ratio, expected in [0, 1)
     * @param weightsAndBias layer weights and biases
     * @param inputs         input vector of the layer
     * @param outSize        number of nodes in the layer
     */
    public NeuralNetwork(String activation, double dropOutRatio, DeeplearningMojoModel.StoreWeightsBias weightsAndBias,
                         double[] inputs, int outSize) {
        validateInputs(activation, dropOutRatio, weightsAndBias._wValues.length, weightsAndBias._bValues.length,
                inputs.length, outSize);
        this._activation = activation;
        this._drop_out_ratio = dropOutRatio;
        this._weightsAndBias = weightsAndBias;
        this._inputs = inputs;
        this._outSize = outSize;
        this._inSize = inputs.length;
        this._outputs = new double[outSize];
        if ("Maxout".equals(activation) || "MaxoutWithDropout".equals(activation)) {
            // Maxout stores one bias per channel, so the bias count reveals the channel count per node.
            this._maxK = weightsAndBias._bValues.length / outSize;
        }
    }

    /**
     * Propagates the inputs through this layer.
     *
     * @return the activation function's output for this layer
     */
    public double[] fprop1Layer() {
        // The activation function is resolved first, so an unknown name fails before any arithmetic is done.
        ActivationUtils.ActivationFunctions activationFunction = createActFuns(_activation);
        double[] preActivation = _maxK == 1 ? formNNInputs() : formNNInputsMaxOut();
        return activationFunction.eval(preActivation, _drop_out_ratio, _maxK);
    }

    /**
     * Computes {@code W * x + b} for a non-Maxout layer.
     *
     * <p>The weights are read row-major: the weight linking input {@code col} to node {@code row} is
     * {@code _wValues[row * _inputs.length + col]}.
     *
     * @return pre-activation values, one per node ({@code _outSize} entries)
     */
    public double[] formNNInputs() {
        final double[] weights = _weightsAndBias._wValues;
        final double[] bias = _weightsAndBias._bValues;
        final double[] inputs = _inputs;
        final double[] result = new double[_outSize];
        final int cols = inputs.length;
        // Largest multiple of 8 not exceeding cols: the 8-way unrolled loop covers [0, unrolledEnd),
        // and the scalar tail loop handles the remaining cols % 8 entries.
        final int unrolledEnd = cols - (cols % 8);
        int rowOffset = 0;
        for (int row = 0; row < result.length; row++) {
            // Eight independent partial sums break the dependency chain of a single accumulator.
            double psum0 = 0.0d;
            double psum1 = 0.0d;
            double psum2 = 0.0d;
            double psum3 = 0.0d;
            double psum4 = 0.0d;
            double psum5 = 0.0d;
            double psum6 = 0.0d;
            double psum7 = 0.0d;
            for (int col = 0; col < unrolledEnd; col += 8) {
                final int w = rowOffset + col;
                psum0 += weights[w] * inputs[col];
                psum1 += weights[w + 1] * inputs[col + 1];
                psum2 += weights[w + 2] * inputs[col + 2];
                psum3 += weights[w + 3] * inputs[col + 3];
                psum4 += weights[w + 4] * inputs[col + 4];
                psum5 += weights[w + 5] * inputs[col + 5];
                psum6 += weights[w + 6] * inputs[col + 6];
                psum7 += weights[w + 7] * inputs[col + 7];
            }
            // Combine left to right in the same order as before, so results stay bit-for-bit identical.
            double sum = 0.0d;
            sum = sum + psum0 + psum1 + psum2 + psum3;
            sum = sum + psum4 + psum5 + psum6 + psum7;
            for (int col = unrolledEnd; col < cols; col++) {
                sum += weights[rowOffset + col] * inputs[col];
            }
            result[row] = sum + bias[row];
            rowOffset += cols;
        }
        return result;
    }

    /**
     * Computes the pre-activation values of a Maxout layer.
     *
     * <p>Output (and bias) index {@code row * _maxK + k} holds channel {@code k} of node {@code row}. Its weight for
     * input {@code col} is {@code _wValues[_maxK * (row * _inSize + col) + k]}.
     *
     * @return pre-activation values ({@code _outSize * _maxK} entries)
     */
    public double[] formNNInputsMaxOut() {
        final double[] weights = _weightsAndBias._wValues;
        final double[] bias = _weightsAndBias._bValues;
        final double[] result = new double[_outSize * _maxK];
        for (int row = 0; row < _outSize; row++) {
            for (int k = 0; k < _maxK; k++) {
                final int outIdx = (_maxK * row) + k;
                double sum = 0.0d;
                for (int col = 0; col < _inSize; col++) {
                    sum += _inputs[col] * weights[(_maxK * ((row * _inSize) + col)) + k];
                }
                result[outIdx] = sum + bias[outIdx];
            }
        }
        return result;
    }

    /**
     * Sanity-checks the layer configuration. The checks are Java {@code assert}s, so they only run with {@code -ea}.
     *
     * @param activation    activation function name
     * @param dropOutRatio  dropout ratio; must be in [0, 1) (NaN is rejected)
     * @param weightsLength number of weights; must be a multiple of {@code inputsLength * outSize}
     * @param biasLength    number of biases; must be a multiple of {@code outSize}
     * @param inputsLength  number of layer inputs
     * @param outSize       number of nodes in the layer; must be positive
     * @throws AssertionError if assertions are enabled and a check fails
     */
    public void validateInputs(String activation, double dropOutRatio, int weightsLength, int biasLength,
                               int inputsLength, int outSize) {
        assert _validActivation.contains(activation)
                : "activation must be one of \"Linear\", \"Softmax\", \"ExpRectifierWithDropout\", \"ExpRectifier\", \"Rectifier\", \"RectifierWithDropout\", \"MaxoutWithDropout\", \"Maxout\", \"TanhWithDropout\", \"Tanh\"";
        // Checked before the modulo checks below, which would otherwise divide by zero.
        assert outSize > 0 : "number of nodes in neural network must exceed 0.";
        // long arithmetic: inputsLength * outSize may overflow int for very wide layers.
        final long inputsTimesOutputs = (long) inputsLength * outSize;
        assert inputsTimesOutputs == 0 ? weightsLength == 0 : weightsLength % inputsTimesOutputs == 0
                : "Your neural network layer number of input * number of outputs should equal length of your weight vector";
        assert biasLength % outSize == 0
                : "Number of bias should equal number of nodes in your neural network layer.";
        // Written positively so that NaN fails the check.
        assert dropOutRatio >= 0.0d && dropOutRatio < 1.0d : "drop_out_ratio must be >=0 and < 1.";
    }

    /**
     * Maps an activation function name to its implementation.
     *
     * @param str activation function name
     * @return a new activation function instance
     * @throws UnsupportedOperationException if {@code str} is not a known activation name
     * @throws NullPointerException          if {@code str} is null
     */
    public ActivationUtils.ActivationFunctions createActFuns(String str) {
        switch (str) {
            case "Linear":
                return new ActivationUtils.LinearOut();
            case "Softmax":
                return new ActivationUtils.SoftmaxOut();
            case "ExpRectifierWithDropout":
                return new ActivationUtils.ExpRectifierDropoutOut();
            case "ExpRectifier":
                return new ActivationUtils.ExpRectifierOut();
            case "Rectifier":
                return new ActivationUtils.RectifierOut();
            case "RectifierWithDropout":
                return new ActivationUtils.RectifierDropoutOut();
            case "MaxoutWithDropout":
                return new ActivationUtils.MaxoutDropoutOut();
            case "Maxout":
                return new ActivationUtils.MaxoutOut();
            case "TanhWithDropout":
                return new ActivationUtils.TanhDropoutOut();
            case "Tanh":
                return new ActivationUtils.TanhOut();
            default:
                throw new UnsupportedOperationException("Unexpected activation function: " + str);
        }
    }
}
