package edu.emory.clir.clearnlp.bin.helper;

import edu.emory.clir.clearnlp.classification.configuration.AbstractTrainerConfiguration;
import edu.emory.clir.clearnlp.classification.configuration.AdaGradTrainerConfiguration;
import edu.emory.clir.clearnlp.classification.model.AbstractModel;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.trainer.AbstractTrainer;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradLR;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradSVM;
import org.kohsuke.args4j.Option;

/**
 * Command-line driver that trains a classifier with AdaGrad.
 *
 * <p>The trainer is {@link AdaGradSVM} by default, or {@link AdaGradLR} when {@code -logistic} is
 * given; for either one, the sparse-model or string-model constructor is chosen from the model type.
 */
public class AdaGradClassify extends AbstractClassifyOnline {

    /** Learning rate used when {@code -a} is not given on the command line. */
    private static final double DEFAULT_ALPHA = 0.01d;

    /** Ridge value used when {@code -r} is not given on the command line. */
    private static final double DEFAULT_RHO = 0.1d;

    // -a and -r are boxed so that "not given" (null) can be told apart from an explicit value.
    // Their defaults must NOT be assigned in a field initializer or in the constructor body: both
    // run only after super(strArr) returns, and since main() does nothing but construct this
    // object, the options are necessarily parsed and used inside super(strArr) -- before such an
    // assignment could take effect (and the assignment would then overwrite the parsed value).
    // The defaults are therefore applied in createTrainConfiguration().

    @Option(name = "-a", usage = "the learning rate (default: 0.01)", required = false, metaVar = "<double>")
    private Double d_alpha;

    @Option(name = "-r", usage = "the ridge (default: 0.1)", required = false, metaVar = "<double>")
    private Double d_rho;

    // Java's default value for a double field (0.0) is already the intended default.
    @Option(name = "-b", usage = "the bias (default: 0.0)", required = false, metaVar = "<double>")
    private double d_bias;

    // Java's default value for a boolean field (false) is already the intended default.
    @Option(name = "-average", usage = "if set, average weights (default: false)", required = false, metaVar = "<boolean>")
    protected boolean b_average;

    @Option(name = "-logistic", usage = "if set, logistic regression (default: false)", required = false, metaVar = "<boolean>")
    protected boolean b_logistic;

    /**
     * Creates the tool from command-line arguments.
     *
     * <p>All work is delegated to {@code super(strArr)}. No option fields are assigned here,
     * because any assignment would run after the superclass constructor and clobber the parsed
     * values.
     *
     * @param strArr the command-line arguments
     */
    public AdaGradClassify(String[] strArr) {
        super(strArr);
    }

    /**
     * Builds the AdaGrad training configuration from the command-line options.
     *
     * @return an {@link AdaGradTrainerConfiguration}; the learning rate and ridge fall back to
     *     {@value #DEFAULT_ALPHA} and {@value #DEFAULT_RHO} when {@code -a} / {@code -r} were not given
     */
    @Override
    public AbstractTrainerConfiguration createTrainConfiguration() {
        double alpha = (this.d_alpha != null) ? this.d_alpha : DEFAULT_ALPHA;
        double rho = (this.d_rho != null) ? this.d_rho : DEFAULT_RHO;
        return new AdaGradTrainerConfiguration(this.i_vectorType, this.b_binary, this.i_labelCutoff, this.i_featureCutoff, this.i_numberOfThreads, this.b_average, alpha, rho, this.d_bias);
    }

    /**
     * Creates the AdaGrad trainer for the given model.
     *
     * <p>Uses {@link AdaGradLR} when {@code -logistic} was set, otherwise {@link AdaGradSVM}.
     * Sparse models use the constructor without cutoffs; all other models are treated as
     * {@link StringModel}s and also receive the label and feature cutoffs.
     *
     * @param abstractTrainerConfiguration the training configuration; it is cast to
     *     {@link AdaGradTrainerConfiguration}
     * @param abstractModel the model to train
     * @return the trainer bound to {@code abstractModel}
     */
    @Override
    public AbstractTrainer getTrainer(AbstractTrainerConfiguration abstractTrainerConfiguration, AbstractModel<?, ?> abstractModel) {
        AdaGradTrainerConfiguration conf = (AdaGradTrainerConfiguration) abstractTrainerConfiguration;

        if (isSparseModel(abstractModel)) {
            SparseModel sparseModel = (SparseModel) abstractModel;
            if (this.b_logistic) {
                return new AdaGradLR(sparseModel, conf.isAverage(), conf.getLearningRate(), conf.getRidge(), conf.getBias());
            }
            return new AdaGradSVM(sparseModel, conf.isAverage(), conf.getLearningRate(), conf.getRidge(), conf.getBias());
        }

        // Non-sparse models are string models; their trainers additionally take the cutoffs.
        StringModel stringModel = (StringModel) abstractModel;
        if (this.b_logistic) {
            return new AdaGradLR(stringModel, conf.getLabelCutoff(), conf.getFeatureCutoff(), conf.isAverage(), conf.getLearningRate(), conf.getRidge(), conf.getBias());
        }
        return new AdaGradSVM(stringModel, conf.getLabelCutoff(), conf.getFeatureCutoff(), conf.isAverage(), conf.getLearningRate(), conf.getRidge(), conf.getBias());
    }

    /**
     * Command-line entry point; constructing the object runs the tool.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        new AdaGradClassify(strArr);
    }
}
