package ml.shifu.guagua.mapreduce.example.nn;

import com.google.common.base.Splitter;
import java.io.IOException;
import ml.shifu.guagua.GuaguaRuntimeException;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.mapreduce.example.nn.meta.NNParams;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.error.LinearErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.networks.BasicNetwork;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Guagua worker that trains a feed-forward neural network on its input split.
 *
 * <p>Each input line is split on {@code NNConstants.NN_DEFAULT_COLUMN_SEPARATOR}: the first column is the ideal
 * (target) value and the next {@code inputs} columns are the input values; any further columns are ignored.
 * Every loaded record is randomly assigned (about 50/50) to an in-memory training or testing data set.
 *
 * <p>On iteration 1 the worker returns empty params. On later iterations it loads the weights sent by the master
 * into a {@code Gradient}, runs it, and returns the resulting gradients together with the training error and the
 * error on the testing data set.
 */
public class NNWorker extends AbstractWorkerComputable<NNParams, NNParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {

    private static final Logger LOG = LoggerFactory.getLogger(NNWorker.class);

    /** Shared column splitter; Guava splitters are immutable, so one instance serves every record. */
    private static final Splitter COLUMN_SPLITTER = Splitter.on(NNConstants.NN_DEFAULT_COLUMN_SEPARATOR);

    private MLDataSet trainingData = null;
    private MLDataSet testingData = null;

    /** Created lazily from the first master weights received in {@link #doCompute}. */
    private Gradient gradient;

    /** Number of calls to {@link #load} so far (one per input record). */
    private long count;

    private int inputs;
    private int hiddens;
    private int outputs;

    /** How many copies of each record {@link #load} adds; read from {@code NN_RECORD_SCALE} in {@link #init}. */
    private int recordScale = 1;

    private void initMemoryDataSet() {
        this.trainingData = new BasicMLDataSet();
        this.testingData = new BasicMLDataSet();
    }

    /**
     * Reads the network topology and record scale from the job properties and creates empty in-memory data sets.
     *
     * @param workerContext context supplying the job properties
     */
    public void init(WorkerContext<NNParams, NNParams> workerContext) {
        this.inputs = NumberFormatUtils.getInt(workerContext.getProps().getProperty(NNConstants.GUAGUA_NN_INPUT_NODES), 100);
        this.hiddens = NumberFormatUtils.getInt(workerContext.getProps().getProperty(NNConstants.GUAGUA_NN_HIDDEN_NODES), 2);
        this.outputs = NumberFormatUtils.getInt(workerContext.getProps().getProperty(NNConstants.GUAGUA_NN_OUTPUT_NODES), 1);
        // Parsed once here instead of once per record in load().
        this.recordScale = NumberFormatUtils.getInt(workerContext.getProps().getProperty(NNConstants.NN_RECORD_SCALE), 1);
        LOG.info("NNWorker is loading data into memory.");
        initMemoryDataSet();
    }

    /**
     * Computes this worker's contribution for the current iteration.
     *
     * @param workerContext context holding the current iteration and the master's last result
     * @return empty params on iteration 1; {@code null} if the master's last result is missing; otherwise params
     *         carrying gradients, train/test errors and the training set size (weights left empty)
     */
    public NNParams doCompute(WorkerContext<NNParams, NNParams> workerContext) {
        if (workerContext.getCurrentIteration() == 1) {
            // No weights from the master yet; only send placeholder params.
            return buildEmptyNNParams(workerContext);
        }
        NNParams lastMasterResult = workerContext.getLastMasterResult();
        if (lastMasterResult == null) {
            LOG.warn("Master result of last iteration is null.");
            return null;
        }
        LOG.debug("Set current model with params {}", lastMasterResult);
        if (this.gradient == null) {
            initGradient(this.trainingData, lastMasterResult.getWeights());
        }
        this.gradient.setWeights(lastMasterResult.getWeights());
        this.gradient.run();

        double trainError = this.gradient.getError();
        double validationError = this.testingData.getRecordCount() > 0
                ? this.gradient.getNetwork().calculateError(this.testingData)
                : 0.0d;
        LOG.info("NNWorker compute iteration {} (train error {} validation error {})",
                new Object[] { workerContext.getCurrentIteration(), trainError, validationError });

        NNParams params = new NNParams();
        params.setTestError(validationError);
        params.setTrainError(trainError);
        params.setGradients(this.gradient.getGradients());
        params.setWeights(new double[0]);
        params.setTrainSize(this.trainingData.getRecordCount());
        return params;
    }

    /**
     * Builds a network with the configured topology, sets the given weights and wraps it in a {@code Gradient}.
     *
     * @param trainingSet data set the gradient is computed over (via {@code openAdditional()})
     * @param weights flat weight array received from the master
     */
    private void initGradient(MLDataSet trainingSet, double[] weights) {
        BasicNetwork network = NNUtils.generateNetwork(this.inputs, this.hiddens, this.outputs);
        FlatNetwork flat = network.getFlat();
        flat.setWeights(weights);

        // Per-layer values: 0.1 for sigmoid layers, 0 otherwise. Presumably Encog's flat-spot adjustment —
        // confirm against the Gradient constructor.
        int layerCount = flat.getActivationFunctions().length;
        double[] flatSpot = new double[layerCount];
        for (int layer = 0; layer < layerCount; layer++) {
            flatSpot[layer] = flat.getActivationFunctions()[layer] instanceof ActivationSigmoid ? 0.1d : 0.0d;
        }
        this.gradient = new Gradient(flat, trainingSet.openAdditional(), flatSpot, new LinearErrorFunction());
    }

    /** Returns params with empty weights/gradients and zero errors, used before the master has sent weights. */
    private NNParams buildEmptyNNParams(WorkerContext<NNParams, NNParams> workerContext) {
        NNParams params = new NNParams();
        params.setWeights(new double[0]);
        params.setGradients(new double[0]);
        params.setTestError(0.0d);
        params.setTrainError(0.0d);
        return params;
    }

    /** Logs record counts for the whole input and for the training/testing data sets. */
    protected void postLoad(WorkerContext<NNParams, NNParams> workerContext) {
        LOG.info("- # Records of the whole data set: {}.", this.count);
        LOG.info("- # Records of the training data set: {}.", this.trainingData.getRecordCount());
        LOG.info("- # Records of the testing data set: {}.", this.testingData.getRecordCount());
    }

    /**
     * Parses one input line and adds {@code recordScale} copies of it to the training/testing data sets.
     *
     * @param currentKey record key (unused)
     * @param currentValue the text line: ideal value followed by input values
     * @param workerContext worker context (unused; configuration is read in {@link #init})
     * @throws GuaguaRuntimeException if the line has fewer than {@code inputs + 1} columns
     */
    public void load(GuaguaWritableAdapter<LongWritable> currentKey, GuaguaWritableAdapter<Text> currentValue,
            WorkerContext<NNParams, NNParams> workerContext) {
        this.count++;
        if (this.count % 100000 == 0) {
            LOG.info("Read {} records.", this.count);
        }

        int numInputs = this.inputs;
        double[] ideal = new double[1];
        double[] features = new double[numInputs];
        int columns = 0;
        for (String column : COLUMN_SPLITTER.split(currentValue.getWritable().toString())) {
            if (columns > numInputs) {
                // Ideal value plus numInputs inputs already read; remaining columns are ignored.
                break;
            }
            double value = NumberFormatUtils.getDouble(column, 0.0d);
            if (columns == 0) {
                ideal[0] = value;
            } else {
                features[columns - 1] = value;
            }
            columns++;
        }
        if (columns < numInputs + 1) {
            throw new GuaguaRuntimeException(String.format(
                    "Not enough data columns, input nodes setting:%s, data column:%s", numInputs, columns));
        }

        // NOTE(review): each copy is assigned to train/test independently, so with recordScale > 1 duplicates of
        // one record can land in both sets — confirm this is intended.
        for (int copy = 0; copy < this.recordScale; copy++) {
            // Extra copies get their own input AND ideal arrays (the ideal used to be left all zeros).
            double[] copyFeatures = copy == 0 ? features : features.clone();
            double[] copyIdeal = copy == 0 ? ideal : ideal.clone();
            BasicMLDataPair pair = new BasicMLDataPair(new BasicMLData(copyFeatures), new BasicMLData(copyIdeal));
            if (Math.random() >= 0.5d) {
                this.trainingData.add(pair);
            } else {
                this.testingData.add(pair);
            }
        }
    }

    /** Reads the split line by line with a {@code GuaguaLineRecordReader}. */
    public void initRecordReader(GuaguaFileSplit guaguaFileSplit) throws IOException {
        setRecordReader(new GuaguaLineRecordReader());
        getRecordReader().initialize(guaguaFileSplit);
    }

    public MLDataSet getTrainingData() {
        return this.trainingData;
    }

    public void setTrainingData(MLDataSet trainingData) {
        this.trainingData = trainingData;
    }

    public MLDataSet getTestingData() {
        return this.testingData;
    }

    public void setTestingData(MLDataSet testingData) {
        this.testingData = testingData;
    }
}
