package ml.shifu.guagua.example.lr;

import java.util.Arrays;
import java.util.Random;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.master.MasterComputable;
import ml.shifu.guagua.master.MasterContext;
import ml.shifu.guagua.util.NumberFormatUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/shifu/guagua/example/lr/LogisticRegressionMaster.class */
/**
 * Master-side computation for the logistic regression example.
 *
 * <p>On the first iteration the master reads its settings and creates a randomly initialised
 * weight vector of length {@code inputNum + 1}. On every later iteration it sums the parameter
 * arrays reported by the workers (used here as gradient contributions), moves the weights
 * against that sum scaled by the learning rate, and logs the mean worker error. The current
 * weights are returned as the master result of every iteration.
 *
 * <p>The instance keeps the weights between calls, so {@link #compute(MasterContext)} relies on
 * the first-iteration call having happened on the same instance.
 */
public class LogisticRegressionMaster implements MasterComputable<LogisticRegressionParams, LogisticRegressionParams> {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionMaster.class);
    private static final Random RANDOM = new Random();

    /** Property key holding the number of input features. */
    private static final String INPUT_NUM_KEY = "lr.input.num";
    /** Value used when {@link #INPUT_NUM_KEY} is absent or not a valid integer. */
    private static final int DEFAULT_INPUT_NUM = 2;
    /** Property key holding the learning rate. */
    private static final String LEARNING_RATE_KEY = "lr.learning.rate";
    /** Value used when {@link #LEARNING_RATE_KEY} is absent or not a valid number. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    private int inputNum;
    private double[] weights;
    private double learnRate;

    /**
     * Loads the input count and learning rate from the job properties.
     *
     * <p>The property <em>values</em> are parsed, not the key names: passing the key literal to
     * the number parser can never succeed and would silently pin both settings to their defaults.
     * NOTE(review): the worker must read the same {@code lr.input.num} key, otherwise the
     * parameter arrays it reports will not match the weight vector length - confirm.
     *
     * @param masterContext context of the current iteration, used only for its properties
     */
    private void init(MasterContext<LogisticRegressionParams, LogisticRegressionParams> masterContext) {
        String rawInputNum = property(masterContext, INPUT_NUM_KEY);
        this.inputNum = rawInputNum == null
                ? DEFAULT_INPUT_NUM
                : NumberFormatUtils.getInt(rawInputNum, DEFAULT_INPUT_NUM);

        // Checked for null here rather than inside the parser: Double.parseDouble(null) raises
        // NullPointerException, which a NumberFormatException-only handler would not absorb.
        String rawLearnRate = property(masterContext, LEARNING_RATE_KEY);
        this.learnRate = rawLearnRate == null
                ? DEFAULT_LEARNING_RATE
                : NumberFormatUtils.getDouble(rawLearnRate, DEFAULT_LEARNING_RATE);
    }

    /**
     * Returns the trimmed value of a job property.
     *
     * @param masterContext context whose properties are consulted
     * @param key property name
     * @return the trimmed value, or {@code null} when there are no properties or the key is unset
     */
    private static String property(MasterContext<LogisticRegressionParams, LogisticRegressionParams> masterContext,
            String key) {
        if (masterContext.getProps() == null) {
            return null;
        }
        String value = masterContext.getProps().getProperty(key);
        return value == null ? null : value.trim();
    }

    /**
     * Runs one master iteration.
     *
     * @param masterContext context of the current iteration; supplies the iteration number and
     *        the worker results
     * @return a new params object wrapping the master's current weight array
     * @throws IllegalStateException if a worker reports fewer parameters than there are weights
     */
    public LogisticRegressionParams compute(MasterContext<LogisticRegressionParams, LogisticRegressionParams> masterContext) {
        if (masterContext.isFirstIteration()) {
            init(masterContext);
            // One slot more than the number of inputs (presumably the bias term).
            this.weights = new double[this.inputNum + 1];
            for (int i = 0; i < this.weights.length; i++) {
                this.weights[i] = RANDOM.nextDouble();
            }
            return new LogisticRegressionParams(this.weights);
        }

        double[] gradientSum = new double[this.weights.length];
        double errorSum = 0.0d;
        int reportingWorkers = 0;
        for (LogisticRegressionParams workerResult : masterContext.getWorkerResults()) {
            if (workerResult == null) {
                // A worker without a result contributes nothing and must not dilute the mean error.
                continue;
            }
            double[] workerParameters = workerResult.getParameters();
            if (workerParameters.length < gradientSum.length) {
                throw new IllegalStateException("Worker reported " + workerParameters.length
                        + " parameters but the master holds " + gradientSum.length + " weights");
            }
            for (int i = 0; i < gradientSum.length; i++) {
                gradientSum[i] += workerParameters[i];
            }
            errorSum += workerResult.getError();
            reportingWorkers++;
        }

        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] -= this.learnRate * gradientSum[i];
        }

        if (LOG.isDebugEnabled()) {
            // Guarded so the weight vector is only rendered when it will actually be logged.
            LOG.debug("DEBUG: Weights: {}", Arrays.toString(this.weights));
        }
        // With no reporting worker this is 0.0 / 0 and logs NaN, making the gap visible.
        LOG.info("Iteration {} with error {}", masterContext.getCurrentIteration(), errorSum / reportingWorkers);
        return new LogisticRegressionParams(this.weights);
    }

    /*
     * Synthetic bridge emitted by the compiler for the erased interface method; the decompiler
     * renamed it because it collides with the typed compute method above. Kept unchanged.
     */
    /* renamed from: compute, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Bytable m12compute(MasterContext masterContext) {
        return compute((MasterContext<LogisticRegressionParams, LogisticRegressionParams>) masterContext);
    }
}
