/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.modeler.param;

import com.google.common.collect.ImmutableSet;
import io.cdap.mmds.modeler.param.RegressionParams;
import io.cdap.mmds.spec.DoubleParam;
import io.cdap.mmds.spec.ParamSpec;
import io.cdap.mmds.spec.Params;
import io.cdap.mmds.spec.Range;
import io.cdap.mmds.spec.StringParam;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.spark.ml.classification.LogisticRegression;

/**
 * Model parameters for Spark's {@link LogisticRegression} estimator.
 *
 * <p>Declares a {@code threshold} and a {@code family} parameter in addition to the parameters
 * inherited from {@link RegressionParams}, and applies all of them to a
 * {@link LogisticRegression} instance via {@link #setParams(LogisticRegression)}.
 */
public class LogisticRegressionParams extends RegressionParams {

    /** Values accepted for the {@code family} parameter. */
    private static final Set<String> FAMILIES = ImmutableSet.of("auto", "binomial", "multinomial");

    private final DoubleParam threshold;
    private final StringParam family;

    /**
     * Builds the parameter set from user-supplied values.
     *
     * @param modelParams raw parameter values; passed to the parent constructor and to each
     *     parameter declared by this class
     */
    public LogisticRegressionParams(Map<String, String> modelParams) {
        super(modelParams);
        // Default 0.5. NOTE(review): the two trailing booleans of Range presumably mark both
        // bounds (0.0 and 1.0) as inclusive -- confirm against io.cdap.mmds.spec.Range.
        this.threshold = new DoubleParam("threshold", "Threshold",
            "Threshold in binary classification. "
                + "If the estimated probability of class label 1 is greater than threshold, "
                + "then predict 1, else 0. "
                + "A high threshold encourages the model to predict 0 more often. "
                + "A low threshold encourages the model to predict 1 more often.",
            0.5, new Range(0.0, 1.0, true, true), modelParams);
        // The allowed values must be a Set<String>; a typed ImmutableSet avoids the illegal
        // ImmutableSet<Object> -> Set<String> cast that does not compile.
        this.family = new StringParam("family", "Family",
            "Label distribution to be used in the model. "
                + "'auto' will automatically select the family based on the number of classes. "
                + "If numClasses == 1 or numClasses == 2, sets to 'binomial'. "
                + "Else, sets to 'multinomial'. "
                + "'binomial' uses binary logistic regression with pivoting. "
                + "'multinomial' uses multinomial logistic (softmax) regression without pivoting.",
            "auto", FAMILIES, modelParams);
    }

    /**
     * Returns the parameter specs of the parent class combined with the specs for
     * {@code threshold} and {@code family}.
     *
     * @return the combined list produced by {@code Params.addParams}
     */
    @Override
    public List<ParamSpec> getSpec() {
        return Params.addParams(super.getSpec(), this.threshold, this.family);
    }

    /**
     * Returns the parent class's parameter map with {@code threshold} and {@code family} added.
     *
     * @return the map produced by {@code Params.putParams}
     */
    @Override
    public Map<String, String> toMap() {
        return Params.putParams(super.toMap(), this.threshold, this.family);
    }

    /**
     * Copies every parameter value onto the given Spark estimator.
     *
     * <p>This covers the inherited iteration, standardization, regularization, elastic-net and
     * tolerance settings, plus this class's threshold and family.
     *
     * @param modeler the estimator to configure; it is mutated in place
     */
    public void setParams(LogisticRegression modeler) {
        modeler.setMaxIter(((Integer) this.maxIterations.getVal()).intValue());
        modeler.setStandardization(((Boolean) this.standardization.getVal()).booleanValue());
        modeler.setRegParam(((Double) this.regularizationParam.getVal()).doubleValue());
        modeler.setElasticNetParam(((Double) this.elasticNetParam.getVal()).doubleValue());
        modeler.setTol(((Double) this.tolerance.getVal()).doubleValue());
        modeler.setThreshold(((Double) this.threshold.getVal()).doubleValue());
        modeler.setFamily((String) this.family.getVal());
    }
}

