package ml.shifu.guagua.example.nn;

/* loaded from: input_file:ml/shifu/guagua/example/nn/Weight.class */
/**
 * Per-weight update rules for a neural-network trainer.
 *
 * <p>For a fixed number of weights, an instance keeps three pieces of state:
 * <ul>
 *   <li>the previous step ({@code lastDelta});</li>
 *   <li>the previously stored gradient ({@code lastGradient});</li>
 *   <li>an adaptive step size ({@code updateValues}).</li>
 * </ul>
 * {@link #calculateWeights(double[], double[])} applies one update round in place using the rule
 * selected by the algorithm name given to the constructor. Instances are mutable and not
 * synchronized.
 */
public class Weight {
    private static final double ZERO_TOLERANCE = 1.0E-17d;
    /** Starting adaptive step size for every weight (read by the resilient rule). */
    private static final double DEFAULT_INITIAL_UPDATE = 0.1d;
    /** Upper bound on the resilient rule's adaptive step size. */
    private static final double DEFAULT_MAX_STEP = 50.0d;
    /** Step-size growth factor when the gradient keeps its sign (resilient rule). */
    private static final double RPROP_POSITIVE_ETA = 1.2d;
    /** Step-size shrink factor when the gradient flips its sign (resilient rule). */
    private static final double RPROP_NEGATIVE_ETA = 0.5d;
    /** Lower bound on the resilient rule's adaptive step size. */
    private static final double RPROP_MIN_STEP = 1.0E-6d;
    /**
     * Algorithm name that selects the quick-propagation rule.
     * NOTE(review): presumably mirrors a value in NNConstants; confirm and reuse it if so.
     */
    private static final String QUICK_PROPAGATION = "Q";

    private final double learningRate;
    private final String algorithm;
    private final double[] lastDelta;
    private final double[] lastGradient;
    /** Linear-term factor of the quick-propagation rule: outputEpsilon / numTrainSize. */
    private final double eps;
    /** Quick-propagation shrink threshold: rate / (1 + rate). */
    private final double shrink;
    private final double[] updateValues;
    private final double decay = 1.0E-4d;
    private final double outputEpsilon = 0.35d;
    private final double momentum = 0.0d;

    /**
     * Creates zero-initialised update state for {@code numWeights} weights.
     *
     * @param numWeights   number of weights; sizes every internal state array
     * @param numTrainSize divisor used to derive the quick-propagation linear-term factor
     *                     ({@code eps = 0.35 / numTrainSize}). Presumably the training-set size;
     *                     confirm against callers. Passing 0 makes {@code eps} infinite.
     * @param rate         learning rate. The quick-propagation rule also uses it to size its
     *                     maximum step and to derive its shrink threshold {@code rate / (1 + rate)}.
     * @param algorithm    name of the update rule, matched case-insensitively on each update
     */
    public Weight(int numWeights, double numTrainSize, double rate, String algorithm) {
        this.lastDelta = new double[numWeights];
        this.lastGradient = new double[numWeights];
        this.eps = this.outputEpsilon / numTrainSize;
        this.shrink = rate / (1.0d + rate);
        this.learningRate = rate;
        this.algorithm = algorithm;
        this.updateValues = new double[numWeights];
        // lastDelta and lastGradient already start at 0.0 (Java array default);
        // only the adaptive step sizes need an explicit starting value.
        for (int k = 0; k < numWeights; k++) {
            this.updateValues[k] = DEFAULT_INITIAL_UPDATE;
        }
    }

    /**
     * Applies one update round to {@code weights} in place.
     *
     * @param weights   current weights. Entry {@code i} is incremented by the step computed for
     *                  index {@code i}.
     * @param gradients per-weight gradient values. The length of this array determines how many
     *                  weights are updated. The back-propagation rule adds
     *                  {@code gradient * rate}, so these are presumably already oriented as
     *                  descent directions; confirm with the producer.
     * @return the same {@code weights} array, after mutation
     */
    public double[] calculateWeights(double[] weights, double[] gradients) {
        for (int i = 0; i < gradients.length; i++) {
            weights[i] += updateWeight(i, weights, gradients);
        }
        return weights;
    }

    /**
     * Dispatches to the rule named by {@code algorithm} (case-insensitive). An unrecognised name
     * yields a zero step, leaving the weight unchanged.
     */
    private double updateWeight(int index, double[] weights, double[] gradients) {
        if (this.algorithm.equalsIgnoreCase(NNConstants.BACK_PROPAGATION)) {
            return updateWeightBP(index, weights, gradients);
        } else if (this.algorithm.equalsIgnoreCase(QUICK_PROPAGATION)) {
            return updateWeightQBP(index, weights, gradients);
        } else if (this.algorithm.equalsIgnoreCase(NNConstants.MANHATTAN_PROPAGATION)) {
            return updateWeightMHP(index, weights, gradients);
        } else if (this.algorithm.equalsIgnoreCase(NNConstants.SCALEDCONJUGATEGRADIENT)) {
            return updateWeightSCG(index, weights, gradients);
        } else if (this.algorithm.equalsIgnoreCase(NNConstants.RESILIENTPROPAGATION)) {
            return updateWeightRLP(index, weights, gradients);
        }
        return 0.0d;
    }

    /** Back-propagation step: {@code gradient * learningRate + lastDelta * momentum}. */
    private double updateWeightBP(int index, double[] weights, double[] gradients) {
        double step = (gradients[index] * this.learningRate) + (this.lastDelta[index] * this.momentum);
        this.lastDelta[index] = step;
        return step;
    }

    /**
     * Quickprop-style step.
     *
     * <p>The step combines an optional linear term {@code -eps * slope} with one of two terms:
     * <ul>
     *   <li>a capped step {@code learningRate * lastStep};</li>
     *   <li>a quadratic (secant) estimate built from the current and previous slope.</li>
     * </ul>
     */
    private double updateWeightQBP(int index, double[] weights, double[] gradients) {
        double lastStep = this.lastDelta[index];
        // Current slope, including the weight-decay term.
        double slope = -gradients[index] + (this.decay * weights[index]);
        double prevSlope = -this.lastGradient[index];
        double step = 0.0d;
        if (lastStep < 0.0d) {
            if (slope > 0.0d) {
                // Slope still positive after a negative step: add the linear term.
                step -= this.eps * slope;
            }
            if (slope >= this.shrink * prevSlope) {
                step += this.learningRate * lastStep;
            } else {
                step += quadraticStep(lastStep, slope, prevSlope);
            }
        } else if (lastStep > 0.0d) {
            if (slope < 0.0d) {
                // Slope still negative after a positive step: add the linear term.
                step -= this.eps * slope;
            }
            if (slope <= this.shrink * prevSlope) {
                step += this.learningRate * lastStep;
            } else {
                step += quadraticStep(lastStep, slope, prevSlope);
            }
        } else {
            // No previous step (or NaN): use only the linear term.
            step -= this.eps * slope;
        }
        this.lastDelta[index] = step;
        this.lastGradient[index] = gradients[index];
        return step;
    }

    /**
     * Secant estimate {@code lastStep * slope / (prevSlope - slope)}.
     *
     * <p>When the two slopes are equal, the estimate would divide by zero and produce an
     * infinite or NaN step. The shrink tests above do not rule this out when the slopes have the
     * "wrong" sign. In that case this method falls back to the capped step
     * {@code learningRate * lastStep}, which is the bound the shrink test is meant to enforce.
     */
    private double quadraticStep(double lastStep, double slope, double prevSlope) {
        double denominator = prevSlope - slope;
        if (denominator == 0.0d) {
            return this.learningRate * lastStep;
        }
        return (lastStep * slope) / denominator;
    }

    /**
     * Manhattan step: {@code +learningRate} or {@code -learningRate} depending on the gradient's
     * sign. Returns 0 when |gradient| is below {@code ZERO_TOLERANCE}.
     */
    private double updateWeightMHP(int index, double[] weights, double[] gradients) {
        double gradient = gradients[index];
        if (Math.abs(gradient) < ZERO_TOLERANCE) {
            return 0.0d;
        }
        return gradient > 0.0d ? this.learningRate : -this.learningRate;
    }

    /**
     * Placeholder: no scaled-conjugate-gradient rule is implemented here. It always returns 0,
     * so weights are left unchanged.
     */
    private double updateWeightSCG(int index, double[] weights, double[] gradients) {
        return 0.0d;
    }

    /**
     * Resilient-propagation step with backtracking. The behaviour depends on the sign of the
     * current gradient times the stored gradient:
     * <ul>
     *   <li>Same sign: grow the adaptive step (capped at {@code DEFAULT_MAX_STEP}) and step in
     *       the gradient's direction.</li>
     *   <li>Sign flip: shrink the adaptive step (floored at {@code RPROP_MIN_STEP}), undo the
     *       previous step, and clear the stored gradient so the next call takes the neutral
     *       branch.</li>
     *   <li>Zero product: step by the current adaptive size in the gradient's direction.</li>
     * </ul>
     */
    private double updateWeightRLP(int index, double[] weights, double[] gradients) {
        double gradient = gradients[index];
        int change = NNUtils.sign(gradient * this.lastGradient[index]);
        double step;
        if (change > 0) {
            double grown = Math.min(this.updateValues[index] * RPROP_POSITIVE_ETA, DEFAULT_MAX_STEP);
            step = NNUtils.sign(gradient) * grown;
            this.updateValues[index] = grown;
            this.lastGradient[index] = gradient;
        } else if (change < 0) {
            this.updateValues[index] = Math.max(this.updateValues[index] * RPROP_NEGATIVE_ETA, RPROP_MIN_STEP);
            step = -this.lastDelta[index];
            this.lastGradient[index] = 0.0d;
        } else {
            step = NNUtils.sign(gradient) * this.updateValues[index];
            this.lastGradient[index] = gradient;
        }
        this.lastDelta[index] = step;
        return step;
    }
}
