package edu.emory.clir.clearnlp.classification.trainer;

import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.MathUtils;

/* loaded from: input_file:edu/emory/clir/clearnlp/classification/trainer/AdaGradLR.class */
/**
 * Online multinomial logistic-regression trainer using AdaGrad updates.
 *
 * <p>For each instance the per-label gradient is {@code onehot(gold) - scores}. Here
 * {@code scores} comes from {@code w_vector.getScores(x, true)}; presumably these are
 * normalized probabilities (TODO confirm against the weight-vector implementation). The
 * squared gradients are accumulated into {@code d_gradients}, and the weights are then
 * updated through {@code updateWeight}, which is inherited from {@link AbstractAdaGrad}.
 */
public class AdaGradLR extends AbstractAdaGrad {
    /**
     * An instance is skipped (no update) when the gold label's gradient component,
     * i.e. {@code 1 - score(gold)}, is at most this value.
     */
    private static final double MIN_GOLD_GRADIENT = 0.01d;

    /**
     * Creates a trainer over a {@link SparseModel}. All arguments are forwarded unchanged to
     * {@link AbstractAdaGrad}; see that class for their meaning.
     */
    public AdaGradLR(SparseModel sparseModel, boolean z, double d, double d2, double d3) {
        super(sparseModel, z, d, d2, d3);
    }

    /**
     * Creates a trainer over a {@link StringModel}. All arguments are forwarded unchanged to
     * {@link AbstractAdaGrad}; see that class for their meaning.
     */
    public AdaGradLR(StringModel stringModel, int i, int i2, boolean z, double d, double d2, double d3) {
        super(stringModel, i, i2, z, d, d2, d3);
    }

    /**
     * Performs one AdaGrad step for a single training instance.
     *
     * @param intInstance the training instance (gold label plus sparse feature vector)
     * @param i counter forwarded unchanged to {@code updateWeight}; presumably the
     *          averaging count (TODO confirm in {@link AbstractAdaGrad})
     * @return {@code true} if the model was updated; {@code false} if the gold label's
     *         gradient was at most {@value #MIN_GOLD_GRADIENT} and the instance was skipped
     */
    @Override
    protected boolean update(IntInstance intInstance, int i) {
        double[] gradients = getGradients(intInstance);

        // Small gold gradient => the model already scores the gold label highly; skip.
        if (gradients[intInstance.getLabel()] <= MIN_GOLD_GRADIENT) {
            return false;
        }

        updateGradients(intInstance, gradients);
        updateWeights(intInstance, gradients, i);
        return true;
    }

    /**
     * Computes the per-label gradient {@code onehot(gold) - scores}.
     *
     * <p>The array returned by {@code getScores} is negated in place and reused as the result.
     */
    private double[] getGradients(IntInstance instance) {
        double[] gradients = this.w_vector.getScores(instance.getFeatureVector(), true);

        for (int label = 0; label < gradients.length; label++) {
            gradients[label] = -gradients[label];
        }

        gradients[instance.getLabel()] += 1.0d;
        return gradients;
    }

    /**
     * Adds the squared gradient of every weight touched by {@code instance} to the AdaGrad
     * accumulator {@code d_gradients}.
     *
     * <p>The gradient of weight (label, feature) is {@code g[label] * x[feature]}, so its
     * square is {@code g[label]^2 * x[feature]^2}. The bias is a feature at index 0 whose
     * value is {@code d_bias}, so it must be squared like every other feature value.
     * (Previously {@code d_bias} was added unsquared, which mis-scaled the bias accumulator.)
     */
    private void updateGradients(IntInstance instance, double[] gradients) {
        SparseFeatureVector x = instance.getFeatureVector();
        int labelSize = this.w_vector.getLabelSize();
        double[] squaredGradients = new double[labelSize];

        for (int label = 0; label < labelSize; label++) {
            squaredGradients[label] = MathUtils.sq(gradients[label]);
        }

        accumulateSquaredGradients(squaredGradients, 0, MathUtils.sq(this.d_bias));

        for (int j = 0; j < x.size(); j++) {
            accumulateSquaredGradients(squaredGradients, x.getIndex(j), MathUtils.sq(x.getWeight(j)));
        }
    }

    /**
     * For every label, adds {@code squaredFeatureValue * squaredGradients[label]} to the
     * accumulator slot of weight (label, {@code featureIndex}).
     */
    private void accumulateSquaredGradients(double[] squaredGradients, int featureIndex, double squaredFeatureValue) {
        int labelSize = this.w_vector.getLabelSize();

        for (int label = 0; label < labelSize; label++) {
            int weightIndex = this.w_vector.getWeightIndex(label, featureIndex);
            this.d_gradients[weightIndex] += squaredFeatureValue * squaredGradients[label];
        }
    }

    /**
     * Applies the gradient to the bias weights (feature index 0, value {@code d_bias}) and to
     * every weight of the instance's features.
     */
    private void updateWeights(IntInstance instance, double[] gradients, int count) {
        SparseFeatureVector x = instance.getFeatureVector();

        updateWeights(gradients, 0, this.d_bias, count);

        for (int j = 0; j < x.size(); j++) {
            updateWeights(gradients, x.getIndex(j), x.getWeight(j), count);
        }
    }

    /**
     * For every label, passes the gradient {@code featureValue * gradients[label]} of weight
     * (label, {@code featureIndex}) to the inherited {@code updateWeight}.
     */
    private void updateWeights(double[] gradients, int featureIndex, double featureValue, int count) {
        int labelSize = this.w_vector.getLabelSize();

        for (int label = 0; label < labelSize; label++) {
            updateWeight(this.w_vector.getWeightIndex(label, featureIndex), featureValue * gradients[label], count);
        }
    }

    /** Returns the trainer description built by {@code getTrainerInfo} with the tag {@code "LR"}. */
    @Override
    public String trainerInfo() {
        return getTrainerInfo("LR");
    }
}
