package ml.shifu.guagua.mapreduce.example.lr;

import com.google.common.base.Splitter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.util.NumberFormatUtils;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/shifu/guagua/mapreduce/example/lr/LogisticRegressionWorker.class */
/**
 * Guagua worker for the logistic-regression example.
 *
 * <p>Each input line is split on commas. The first {@code inputNum} fields are features, and the next
 * {@code outputNum} (always 1) fields are targets. Any further fields are ignored. A constant bias
 * of {@code 1.0} is stored at index 0 of every feature vector.
 *
 * <p>On every iteration after the first, {@link #doCompute} uses the weights from the last master
 * result. It sums {@code (sigmoid(w . x) - y) * x} over all loaded records and returns that sum
 * together with the mean of {@code (sigmoid(w . x) - y)^2 / 2}.
 */
public class LogisticRegressionWorker extends AbstractWorkerComputable<LogisticRegressionParams, LogisticRegressionParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionWorker.class);

    /** Job property holding the number of feature columns per record. */
    private static final String LR_INPUT_NUM = "lr.input.num";

    /** Feature-column count used when {@link #LR_INPUT_NUM} does not supply one. */
    private static final int DEFAULT_INPUT_NUM = 2;

    /** Number of feature columns per record (excluding the bias term). */
    private int inputNum;
    /** Number of target columns per record; fixed to 1 in {@link #init}. */
    private int outputNum;
    /** Records loaded by {@link #load}; only appended to and iterated. */
    private List<Data> dataList;
    /** Weights taken from the most recent master result. */
    private double[] weights;
    private final Splitter splitter = Splitter.on(',');

    /**
     * One parsed training record. {@code inputs[0]} is the bias term (1.0), and the remaining entries
     * are the feature columns.
     */
    // NOTE(review): JADX reports this class was originally private; kept public to avoid changing the
    // visible interface.
    public static class Data {
        private final double[] inputs;
        private final double[] outputs;

        public Data(double[] inputs, double[] outputs) {
            this.inputs = inputs;
            this.outputs = outputs;
        }
    }

    /**
     * Installs a line-oriented record reader over the given split.
     *
     * @param guaguaFileSplit the split this worker reads
     * @throws IOException if the reader cannot be created
     */
    public void initRecordReader(GuaguaFileSplit guaguaFileSplit) throws IOException {
        setRecordReader(new GuaguaLineRecordReader(guaguaFileSplit));
    }

    /**
     * Reads the feature-column count from the job properties and prepares the record store.
     *
     * @param workerContext the worker context whose properties are consulted
     * @throws IllegalArgumentException if the configured feature count is negative
     */
    public void init(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> workerContext) {
        // NumberFormatUtils.getInt parses its FIRST argument as the number (guagua API). Passing the
        // key string itself would therefore always yield the default, so look the key up first.
        // NOTE(review): the master should read the same property. Otherwise the weight vector it sends
        // and the gradient vector returned here can differ in length. Confirm.
        String configuredInputNum = workerContext.getProps().getProperty(LR_INPUT_NUM);
        this.inputNum = NumberFormatUtils.getInt(configuredInputNum, DEFAULT_INPUT_NUM);
        if (this.inputNum < 0) {
            throw new IllegalArgumentException(LR_INPUT_NUM + " must be non-negative, got " + this.inputNum);
        }
        this.outputNum = 1;
        this.dataList = new ArrayList<Data>();
    }

    /**
     * Computes this worker's gradient contribution for the current iteration.
     *
     * <p>On the first iteration an empty {@link LogisticRegressionParams} is returned. Afterwards the
     * weights from the last master result are applied to every loaded record.
     *
     * @param workerContext the worker context supplying the iteration and the last master result
     * @return the summed gradient and the mean half-squared error. The error is 0.0 when no records
     *         were loaded, so a NaN is never reported.
     */
    public LogisticRegressionParams doCompute(WorkerContext<LogisticRegressionParams, LogisticRegressionParams> workerContext) {
        if (workerContext.isFirstIteration()) {
            return new LogisticRegressionParams();
        }
        this.weights = workerContext.getLastMasterResult().getParameters();
        double[] gradients = new double[this.inputNum + 1];
        double errorSum = 0.0d;
        int count = 0;
        for (Data data : this.dataList) {
            // Prediction error for this record; it scales both the loss and the gradient term.
            double error = sigmoid(data.inputs, this.weights) - data.outputs[0];
            errorSum += (error * error) / 2.0d;
            for (int j = 0; j < gradients.length; j++) {
                gradients[j] += error * data.inputs[j];
            }
            count++;
        }
        // An empty split would otherwise produce 0.0 / 0 = NaN and poison the reported error.
        double meanError = count == 0 ? 0.0d : errorSum / count;
        LOG.info("Iteration {} with error {}", Integer.valueOf(workerContext.getCurrentIteration()), Double.valueOf(meanError));
        return new LogisticRegressionParams(gradients, meanError);
    }

    /**
     * Returns {@code 1 / (1 + exp(-(weights . inputs)))}. The dot product runs over
     * {@code weights.length} entries, so {@code inputs} must be at least that long.
     */
    private double sigmoid(double[] inputs, double[] weights) {
        double dot = 0.0d;
        for (int j = 0; j < weights.length; j++) {
            dot += weights[j] * inputs[j];
        }
        return 1.0d / (1.0d + Math.exp(-dot));
    }

    /**
     * Parses one comma-separated line into a {@link Data} record and stores it.
     *
     * <p>The first {@code inputNum} fields become features at indices 1..inputNum (index 0 is the bias
     * 1.0). The next {@code outputNum} fields become targets, and extra fields are ignored. Blank lines
     * are skipped. Lines with too few fields are skipped with a warning instead of being zero-padded.
     *
     * @param guaguaWritableAdapter the record key (unused)
     * @param guaguaWritableAdapter2 the line text
     * @param workerContext the worker context (unused)
     * @throws NumberFormatException if a consumed field is not a valid double
     */
    public void load(GuaguaWritableAdapter<LongWritable> guaguaWritableAdapter, GuaguaWritableAdapter<Text> guaguaWritableAdapter2, WorkerContext<LogisticRegressionParams, LogisticRegressionParams> workerContext) {
        String line = guaguaWritableAdapter2.getWritable().toString();
        if (line.trim().isEmpty()) {
            // Parsing "" would throw NumberFormatException and abort the worker.
            return;
        }
        int expectedColumns = this.inputNum + this.outputNum;
        double[] inputs = new double[this.inputNum + 1];
        double[] outputs = new double[this.outputNum];
        inputs[0] = 1.0d; // bias term
        int column = 0;
        for (String field : this.splitter.split(line)) {
            if (column >= expectedColumns) {
                break; // trailing columns beyond inputs + outputs are ignored
            }
            if (column < this.inputNum) {
                inputs[column + 1] = Double.parseDouble(field);
            } else {
                outputs[column - this.inputNum] = Double.parseDouble(field);
            }
            column++;
        }
        if (column < expectedColumns) {
            LOG.warn("Skipping record with {} columns, expected at least {}.", Integer.valueOf(column), Integer.valueOf(expectedColumns));
            return;
        }
        this.dataList.add(new Data(inputs, outputs));
    }

    // NOTE(review): the two raw-typed methods below were reconstructed by the decompiler from
    // compiler-generated bridge methods (JADX markers). They only cast and delegate to the generic
    // overloads above.
    public /* bridge */ /* synthetic */ void load(Bytable bytable, Bytable bytable2, WorkerContext workerContext) {
        load((GuaguaWritableAdapter<LongWritable>) bytable, (GuaguaWritableAdapter<Text>) bytable2, (WorkerContext<LogisticRegressionParams, LogisticRegressionParams>) workerContext);
    }

    /* renamed from: doCompute, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Bytable m15doCompute(WorkerContext workerContext) {
        return doCompute((WorkerContext<LogisticRegressionParams, LogisticRegressionParams>) workerContext);
    }
}
