package ml.shifu.guagua.mapreduce.example.nn;

import java.util.Arrays;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;

/**
 * Back-propagation gradient calculator for an Encog {@link FlatNetwork}.
 *
 * <p>{@link #run()} feeds every record of the training set through the network, accumulates the
 * output error, and sums the per-weight gradients over all records into {@link #getGradients()}.
 * The instance keeps direct references to the network's internal arrays (weights, layer outputs,
 * layer sums, ...), so it reads state written by {@link FlatNetwork#compute(double[], double[])}.
 * NOTE(review): no synchronization is present; confirm each instance is used by a single thread.
 */
public class Gradient {
    /** Network being differentiated; replaced by {@link #setParams(BasicNetwork)}. */
    private FlatNetwork network;
    /** Records visited by {@link #run()}. */
    private final MLDataSet training;
    /** Error reported by {@link #errorCalculation} at the end of the last {@link #run()}. */
    private double error;
    /** Per-layer constant added to the activation derivative (index 0 is used for the output layer). */
    private double[] flatSpot;
    /** Writes the ideal-vs-actual output error into the output slots of {@link #layerDelta}. */
    private final ErrorFunction errorFunction;
    /** Accumulates per-record error; reset at the start of every {@link #run()}. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();

    // Scratch buffers and cached views of the network's internal arrays. They are assigned by
    // bindNetwork(...) once the network is known: computing them in field initializers would
    // dereference `network` while it is still null, because initializers run before the
    // constructor body.
    private double[] layerDelta;
    private double[] gradients;
    private double[] actual;
    private double[] weights;
    private int[] layerIndex;
    private int[] layerCounts;
    private int[] weightIndex;
    private double[] layerOutput;
    private double[] layerSums;
    private int[] layerFeedCounts;
    private MLDataPair pair;

    /**
     * Creates a gradient calculator.
     *
     * @param flatNetwork network to differentiate; its internal arrays are referenced, not copied
     * @param trainingSet records visited by {@link #run()}
     * @param flatSpot per-layer constants added to each activation derivative; must have an entry
     *     for the output layer (index 0) and for every level processed during back-propagation
     * @param errorFunction computes the output-layer error for each record
     */
    public Gradient(FlatNetwork flatNetwork, MLDataSet trainingSet, double[] flatSpot, ErrorFunction errorFunction) {
        this.training = trainingSet;
        this.flatSpot = flatSpot;
        this.errorFunction = errorFunction;
        bindNetwork(flatNetwork);
    }

    /**
     * Points this calculator at {@code flatNetwork} and refreshes every cached reference to its
     * internal arrays. Scratch buffers are reused when their size already matches, so an array
     * previously returned by {@link #getGradients()} stays live across a same-shaped rebind.
     */
    private void bindNetwork(FlatNetwork flatNetwork) {
        this.network = flatNetwork;
        this.weights = flatNetwork.getWeights();
        this.layerIndex = flatNetwork.getLayerIndex();
        this.layerCounts = flatNetwork.getLayerCounts();
        this.weightIndex = flatNetwork.getWeightIndex();
        this.layerOutput = flatNetwork.getLayerOutput();
        this.layerSums = flatNetwork.getLayerSums();
        this.layerFeedCounts = flatNetwork.getLayerFeedCounts();
        this.layerDelta = ensureLength(this.layerDelta, this.layerOutput.length);
        this.gradients = ensureLength(this.gradients, this.weights.length);
        this.actual = ensureLength(this.actual, flatNetwork.getOutputCount());
        // Private scratch pair; getRecord(...) overwrites it for every record.
        this.pair = BasicMLDataPair.createPair(flatNetwork.getInputCount(), flatNetwork.getOutputCount());
    }

    /** Returns {@code current} if it already has {@code length} slots, otherwise a new zeroed array. */
    private static double[] ensureLength(double[] current, int length) {
        return current != null && current.length == length ? current : new double[length];
    }

    /**
     * Runs one record forward, then back-propagates its error into {@link #gradients}.
     *
     * @param input input values of the record
     * @param ideal expected output values of the record
     * @param significance weight of this record; scales its output deltas
     */
    private void process(double[] input, double[] ideal, double significance) {
        FlatNetwork net = this.network;
        net.compute(input, this.actual);
        this.errorCalculation.updateError(this.actual, ideal, significance);
        this.errorFunction.calculateError(ideal, this.actual, this.layerDelta);

        // The first actual.length slots of layerDelta/layerSums/layerOutput belong to the output
        // layer, which uses activation function and flat spot at index 0.
        ActivationFunction outputActivation = net.getActivationFunctions()[0];
        double outputFlatSpot = this.flatSpot[0];
        for (int unit = 0; unit < this.actual.length; unit++) {
            double slope = outputActivation.derivativeFunction(this.layerSums[unit], this.layerOutput[unit]) + outputFlatSpot;
            this.layerDelta[unit] = slope * this.layerDelta[unit] * significance;
        }

        for (int level = net.getBeginTraining(); level < net.getEndTraining(); level++) {
            processLevel(level);
        }
    }

    /**
     * Accumulates the gradients of the weights between level {@code level + 1} and
     * {@code level}, and computes the deltas of level {@code level + 1} from those of
     * {@code level}.
     *
     * @param level index of the level whose deltas are already known
     */
    private void processLevel(int level) {
        int fromLayerStart = this.layerIndex[level + 1];
        int toLayerStart = this.layerIndex[level];
        int fromLayerSize = this.layerCounts[level + 1];
        int toLayerSize = this.layerFeedCounts[level];
        int weightStart = this.weightIndex[level];
        ActivationFunction activation = this.network.getActivationFunctions()[level + 1];
        double levelFlatSpot = this.flatSpot[level + 1];

        int fromUnit = fromLayerStart;
        for (int y = 0; y < fromLayerSize; y++) {
            double output = this.layerOutput[fromUnit];
            double propagated = 0.0d;
            int toUnit = toLayerStart;
            // Weights leaving one source unit are strided by the source layer's unit count.
            int weight = weightStart + y;
            for (int x = 0; x < toLayerSize; x++) {
                this.gradients[weight] += output * this.layerDelta[toUnit];
                propagated += this.weights[weight] * this.layerDelta[toUnit];
                weight += fromLayerSize;
                toUnit++;
            }
            this.layerDelta[fromUnit] = propagated
                    * (activation.derivativeFunction(this.layerSums[fromUnit], this.layerOutput[fromUnit]) + levelFlatSpot);
            fromUnit++;
        }
    }

    /**
     * Resets the error and gradients, processes every training record, and stores the resulting
     * error in {@link #getError()}.
     *
     * @throws RuntimeException wrapping anything thrown while processing the records
     */
    public final void run() {
        try {
            this.errorCalculation.reset();
            Arrays.fill(this.gradients, 0.0d);
            long recordCount = this.training.getRecordCount();
            for (long record = 0; record < recordCount; record++) {
                this.training.getRecord(record, this.pair);
                process(this.pair.getInputArray(), this.pair.getIdealArray(), this.pair.getSignificance());
            }
            this.error = this.errorCalculation.calculate();
        } catch (Throwable th) {
            // Kept for compatibility: callers observe every failure as a RuntimeException.
            throw new RuntimeException(th);
        }
    }

    /** @return the error accumulator used by {@link #run()} */
    public ErrorCalculation getErrorCalculation() {
        return this.errorCalculation;
    }

    /** @return the internal gradient array (not a copy); it is zeroed and refilled by each {@link #run()} */
    public double[] getGradients() {
        return this.gradients;
    }

    /** @return the error computed by the last {@link #run()} */
    public double getError() {
        return this.error;
    }

    /** @return the weight array currently used for back-propagation (not a copy) */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Replaces the weights used here and in the underlying network.
     *
     * @param weights new weight array; referenced, not copied
     */
    public void setWeights(double[] weights) {
        this.weights = weights;
        this.network.setWeights(weights);
    }

    /**
     * Switches to the flat network backing {@code basicNetwork}, refreshing the cached weights and
     * all cached layer arrays so that later runs read the new network's state.
     *
     * @param basicNetwork network whose {@link BasicNetwork#getFlat()} becomes the active network
     */
    public void setParams(BasicNetwork basicNetwork) {
        bindNetwork(basicNetwork.getFlat());
    }

    /** @return the network currently being differentiated */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return the internal per-unit delta buffer (not a copy), overwritten for every record */
    public double[] getLayerDelta() {
        return this.layerDelta;
    }
}
