package edu.emory.clir.clearnlp.classification.trainer;

import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;

/* loaded from: input_file:edu/emory/clir/clearnlp/classification/trainer/AbstractAdaGrad.class */
/**
 * Base class for online trainers that scale each weight update by a
 * per-weight step size of the form {@code alpha / (rho + sqrt(G[i]))}, where
 * {@code G} is {@link #d_gradients}.
 *
 * <p>This class only allocates {@link #d_gradients} and reads it when computing
 * the step size; it never writes to it. Presumably subclasses accumulate
 * (squared) gradients into it during training -- confirm against the concrete
 * trainers.
 */
public abstract class AbstractAdaGrad extends AbstractOnlineTrainer {
    /** Per-weight accumulator read by the step-size computation; one slot per entry of the weight vector. */
    protected double[] d_gradients;
    /** Numerator of the per-weight step size. */
    protected double d_alpha;
    /** Constant added to the square root in the step-size denominator. */
    protected double d_rho;
    /** Bias value; stored and reported here but not otherwise used in this class. */
    protected double d_bias;

    /**
     * Creates a trainer over a sparse model.
     *
     * @param sparseModel model forwarded to the superclass constructor.
     * @param z           flag forwarded to the superclass constructor
     *                    (presumably enables weight averaging -- confirm).
     * @param d           value stored in {@link #d_alpha}.
     * @param d2          value stored in {@link #d_rho}.
     * @param d3          value stored in {@link #d_bias}.
     */
    public AbstractAdaGrad(SparseModel sparseModel, boolean z, double d, double d2, double d3) {
        super(sparseModel, z);
        init(d, d2, d3);
    }

    /**
     * Creates a trainer over a string model.
     *
     * @param stringModel model forwarded to the superclass constructor.
     * @param i           forwarded to the superclass constructor
     *                    (presumably a cutoff -- confirm).
     * @param i2          forwarded to the superclass constructor
     *                    (presumably a cutoff -- confirm).
     * @param z           flag forwarded to the superclass constructor
     *                    (presumably enables weight averaging -- confirm).
     * @param d           value stored in {@link #d_alpha}.
     * @param d2          value stored in {@link #d_rho}.
     * @param d3          value stored in {@link #d_bias}.
     */
    public AbstractAdaGrad(StringModel stringModel, int i, int i2, boolean z, double d, double d2, double d3) {
        super(stringModel, i, i2, z);
        init(d, d2, d3);
    }

    /**
     * Allocates the accumulator array (sized to the weight vector) and stores
     * the hyper-parameters. Must run after the superclass constructor, because
     * it reads {@code w_vector}.
     */
    private void init(double alpha, double rho, double bias) {
        this.d_gradients = new double[this.w_vector.size()];
        this.d_alpha = alpha;
        this.d_rho = rho;
        this.d_bias = bias;
    }

    /**
     * Adds {@code getCost(i) * d} to the weight at index {@code i} and, when
     * {@code average()} is true, adds the same amount multiplied by {@code i2}
     * to {@code d_average[i]}.
     *
     * @param i  index into the weight vector and the accumulator array.
     * @param d  raw update amount, scaled here by the per-weight step size.
     * @param i2 multiplier applied to the update before it is added to the
     *           averaged weights.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void updateWeight(int i, double d, int i2) {
        double delta = getCost(i) * d;
        // The weight vector takes floats, so the update is narrowed on the way
        // in; the averaged array keeps the full double value.
        this.w_vector.add(i, (float) delta);
        if (average()) {
            this.d_average[i] += delta * i2;
        }
    }

    /**
     * Returns the step size for weight {@code i}:
     * {@code d_alpha / (d_rho + sqrt(d_gradients[i]))}.
     */
    private double getCost(int i) {
        return this.d_alpha / (this.d_rho + Math.sqrt(this.d_gradients[i]));
    }

    /**
     * Builds a one-line description of this trainer's configuration.
     *
     * @param str suffix appended to the {@code "AdaGrad-"} prefix.
     * @return the formatted description containing alpha, rho, bias and the
     *         averaging flag.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public String getTrainerInfo(String str) {
        // The third numeric field is the bias; it was previously mislabelled as
        // a second "rho".
        return String.format("AdaGrad-%s: alpha = %4.3f, rho = %4.3f, bias = %4.3f, average = %b", str, this.d_alpha, this.d_rho, this.d_bias, average());
    }
}
