package edu.emory.clir.clearnlp.classification.trainer;

import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.DSUtils;
import edu.emory.clir.clearnlp.util.MathUtils;

/* loaded from: input_file:edu/emory/clir/clearnlp/classification/trainer/RRM.class */
/**
 * Online trainer that, for each training instance, checks whether the gold label wins by a
 * fixed margin and, if not, updates the weight vector through {@link AbstractAdaGrad}.
 *
 * <p>On a violation, the squared feature values (including the bias term at feature index 0)
 * are accumulated into {@code d_gradients}. Each feature value is then added to the gold
 * label's weights and subtracted from the predicted label's weights via the inherited
 * {@code updateWeight}.
 *
 * <p>NOTE(review): the update rule visible here is a margin-1.0 multiclass hinge-style update;
 * confirm this is the intended "RRM" algorithm.
 */
public class RRM extends AbstractAdaGrad {
    /** Amount subtracted from the gold label's score before selecting the arg-max label. */
    private static final double MARGIN = 1.0d;

    /**
     * Creates a trainer over an existing sparse model.
     *
     * <p>All arguments are forwarded unchanged to the {@link AbstractAdaGrad} constructor.
     * Their meaning is defined there.
     */
    public RRM(SparseModel sparseModel, boolean z, double d, double d2, double d3) {
        super(sparseModel, z, d, d2, d3);
    }

    /**
     * Creates a trainer over a string-keyed model.
     *
     * <p>All arguments are forwarded unchanged to the {@link AbstractAdaGrad} constructor.
     * Their meaning is defined there.
     */
    public RRM(StringModel stringModel, int i, int i2, boolean z, double d, double d2, double d3) {
        super(stringModel, i, i2, z, d, d2, d3);
    }

    /**
     * Performs one online update for {@code instance}.
     *
     * @param instance the training instance (gold label plus sparse features)
     * @param count    opaque counter passed through to {@code updateWeight}; its semantics are
     *                 defined by the superclass
     * @return {@code true} if the weights were updated, {@code false} if the margin was satisfied
     */
    @Override
    protected boolean update(IntInstance instance, int count) {
        final int predicted = getBestLabel(instance);
        if (instance.isLabel(predicted)) {
            // Gold label still wins after the margin penalty: nothing to learn from this instance.
            return false;
        }
        final int gold = instance.getLabel();
        updateGradients(instance, gold, predicted);
        updateWeights(instance, gold, predicted, count);
        return true;
    }

    /**
     * Returns the arg-max label after lowering the gold label's score by {@link #MARGIN}.
     *
     * <p>The gold label is returned only if it remains the highest-scoring label after the
     * penalty. Any other result signals a margin violation.
     */
    private int getBestLabel(IntInstance instance) {
        // NOTE(review): mutates the array returned by getScores; assumed to be a fresh copy.
        double[] scores = w_vector.getScores(instance.getFeatureVector());
        scores[instance.getLabel()] -= MARGIN;
        return DSUtils.maxIndex(scores);
    }

    /** Accumulates squared feature values (bias first, at feature index 0) into the gradients. */
    private void updateGradients(IntInstance instance, int gold, int predicted) {
        SparseFeatureVector features = instance.getFeatureVector();
        int size = features.size();
        accumulateGradient(gold, predicted, 0, MathUtils.sq(d_bias));
        for (int k = 0; k < size; k++) {
            accumulateGradient(gold, predicted, features.getIndex(k), MathUtils.sq(features.getWeight(k)));
        }
    }

    /**
     * Adds {@code squared} to the gradient slot(s) for {@code featureIndex}.
     *
     * <p>In binary-label mode the slot is addressed by feature index alone. Otherwise both the
     * gold and the predicted label's slots receive the same amount.
     */
    private void accumulateGradient(int gold, int predicted, int featureIndex, double squared) {
        if (w_vector.isBinaryLabel()) {
            d_gradients[featureIndex] += squared;
            return;
        }
        d_gradients[w_vector.getWeightIndex(gold, featureIndex)] += squared;
        d_gradients[w_vector.getWeightIndex(predicted, featureIndex)] += squared;
    }

    /** Applies the weight update for the bias (feature index 0) and every feature of the instance. */
    private void updateWeights(IntInstance instance, int gold, int predicted, int count) {
        SparseFeatureVector features = instance.getFeatureVector();
        int size = features.size();
        updateFeatureWeight(gold, predicted, count, 0, d_bias);
        for (int k = 0; k < size; k++) {
            updateFeatureWeight(gold, predicted, count, features.getIndex(k), features.getWeight(k));
        }
    }

    /**
     * Pushes the gold label's weight for {@code featureIndex} up by {@code value} and the
     * predicted label's weight down by {@code value}.
     *
     * <p>In binary-label mode there is one weight per feature. The direction is encoded by sign,
     * and it is negated when the gold label is 1.
     */
    private void updateFeatureWeight(int gold, int predicted, int count, int featureIndex, double value) {
        if (w_vector.isBinaryLabel()) {
            updateWeight(featureIndex, gold == 1 ? -value : value, count);
        } else {
            updateWeight(w_vector.getWeightIndex(gold, featureIndex), value, count);
            updateWeight(w_vector.getWeightIndex(predicted, featureIndex), -value, count);
        }
    }

    /**
     * Describes this trainer.
     *
     * <p>Previously this reported {@code "SVM"}, which was inconsistent with the class name and
     * made the info string indistinguishable from an SVM trainer's.
     */
    @Override
    public String trainerInfo() {
        return getTrainerInfo("RRM");
    }
}
