package ml.shifu.guagua.mapreduce.example.nn;

import java.util.concurrent.atomic.AtomicBoolean;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.mapreduce.example.nn.meta.NNParams;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Master-side computation for the Guagua neural-network example.
 *
 * <p>The first call to {@link #compute(MasterContext)} builds a fresh network from the configured
 * input/hidden/output node counts and returns its initial weights. Every later call folds the
 * gradients and training sizes reported by the workers into a global {@link NNParams}, derives new
 * weights through a lazily created {@code Weight} calculator, and returns those weights together
 * with the workers' averaged train and test errors.
 */
public class NNMaster implements MasterComputable<NNParams, NNParams> {

    private static final Logger LOG = LoggerFactory.getLogger(NNMaster.class);

    /** Default passed to {@code NumberFormatUtils.getInt} for the input-layer node count. */
    private static final int DEFAULT_INPUT_NODES = 100;

    /** Default passed to {@code NumberFormatUtils.getInt} for the hidden-layer node count. */
    private static final int DEFAULT_HIDDEN_NODES = 2;

    /** Default passed to {@code NumberFormatUtils.getInt} for the output-layer node count. */
    private static final int DEFAULT_OUTPUT_NODES = 1;

    /** Global state: current weights, plus the gradients and train size accumulated from workers. */
    private NNParams globalNNParams = new NNParams();

    /** Flipped to true by the first {@code compute} call, which only seeds the initial weights. */
    private AtomicBoolean isInitialized = new AtomicBoolean(false);

    /** Created on the first aggregation iteration and reused for every later one. */
    private Weight weightCalculator = null;

    /** Learning rate read from configuration while the initial weights are generated. */
    private double learningRate;

    /**
     * Runs one master iteration.
     *
     * @param masterContext context supplying configuration, worker results and the iteration number
     * @return on the first call, the freshly generated initial weights; afterwards, the updated
     *         weights together with the averaged train/test errors and an empty gradient array
     * @throws IllegalArgumentException on a non-first call, if the worker results are {@code null}
     *         or contain no entries
     */
    public NNParams compute(MasterContext<NNParams, NNParams> masterContext) {
        boolean firstCall = this.isInitialized.compareAndSet(false, true);
        if (!firstCall) {
            return aggregateWorkerResults(masterContext);
        }
        NNParams seed = initWeights(masterContext);
        this.globalNNParams.setWeights(seed.getWeights());
        return seed;
    }

    /** Folds every worker's result into the global state and builds the next master result. */
    private NNParams aggregateWorkerResults(MasterContext<NNParams, NNParams> masterContext) {
        if (masterContext.getWorkerResults() == null) {
            throw new IllegalArgumentException("workers' results are null.");
        }

        double testErrorSum = 0.0d;
        double trainErrorSum = 0.0d;
        int workerCount = 0;
        // NOTE(review): reset() presumably clears last iteration's accumulated gradients/train size
        // before re-accumulating below; confirm against NNParams.
        this.globalNNParams.reset();
        for (NNParams workerResult : masterContext.getWorkerResults()) {
            testErrorSum += workerResult.getTestError();
            trainErrorSum += workerResult.getTrainError();
            this.globalNNParams.accumulateGradients(workerResult.getGradients());
            this.globalNNParams.accumulateTrainSize(workerResult.getTrainSize());
            workerCount++;
        }
        if (workerCount == 0) {
            throw new IllegalArgumentException("workers' results are empty.");
        }

        double[] newWeights = applyGradients();
        // Safe: workerCount is known to be positive here.
        double avgTestError = testErrorSum / workerCount;
        double avgTrainError = trainErrorSum / workerCount;
        LOG.info("NNMaster compute iteration {} ( avg train error {}, avg validation error {} )",
                new Object[] { masterContext.getCurrentIteration(), avgTrainError, avgTestError });

        NNParams result = newResult(avgTrainError, avgTestError, newWeights);
        LOG.debug("master result {} in iteration {}", result, masterContext.getCurrentIteration());
        return result;
    }

    /** Computes new weights from the accumulated gradients and stores them as the global weights. */
    private double[] applyGradients() {
        if (this.weightCalculator == null) {
            // Sized from the first aggregated gradient vector and train size.
            // NOTE(review): "Q" selects the update rule inside Weight; its meaning is defined there.
            this.weightCalculator = new Weight(this.globalNNParams.getGradients().length,
                    this.globalNNParams.getTrainSize(), this.learningRate, "Q");
        }
        double[] newWeights = this.weightCalculator.calculateWeights(
                this.globalNNParams.getWeights(), this.globalNNParams.getGradients());
        this.globalNNParams.setWeights(newWeights);
        return newWeights;
    }

    /**
     * Generates a network from the configured layer sizes and wraps its weights as a master result.
     * As a side effect, caches the configured learning rate for later weight updates.
     */
    private NNParams initWeights(MasterContext<NNParams, NNParams> masterContext) {
        int inputNodes = NumberFormatUtils.getInt(
                masterContext.getProps().getProperty(NNConstants.GUAGUA_NN_INPUT_NODES), DEFAULT_INPUT_NODES);
        int hiddenNodes = NumberFormatUtils.getInt(
                masterContext.getProps().getProperty(NNConstants.GUAGUA_NN_HIDDEN_NODES), DEFAULT_HIDDEN_NODES);
        int outputNodes = NumberFormatUtils.getInt(
                masterContext.getProps().getProperty(NNConstants.GUAGUA_NN_OUTPUT_NODES), DEFAULT_OUTPUT_NODES);
        this.learningRate = NumberFormatUtils.getDouble(masterContext.getProps().getProperty(
                NNConstants.GUAGUA_NN_LEARNING_RATE, NNConstants.GUAGUA_NN_DEFAULT_LEARNING_RATE));

        BasicNetwork network = NNUtils.generateNetwork(inputNodes, hiddenNodes, outputNodes);
        return newResult(0.0d, 0.0d, network.getFlat().getWeights());
    }

    /** Builds a result carrying the given errors and weights, with an empty gradient array. */
    private static NNParams newResult(double trainError, double testError, double[] weights) {
        NNParams result = new NNParams();
        result.setTrainError(trainError);
        result.setTestError(testError);
        result.setGradients(new double[0]);
        result.setWeights(weights);
        return result;
    }

    /**
     * Raw-typed entry point kept from the decompiled class; it only delegates to
     * {@link #compute(MasterContext)}. The decompiler emitted it under a new name because it
     * collided with the compiler-generated bridge method for {@code compute}.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public Bytable m18compute(MasterContext masterContext) {
        return compute((MasterContext<NNParams, NNParams>) masterContext);
    }
}
