package cc.factorie.app.regress;

import cc.factorie.la.Tensor1;
import cc.factorie.la.Tensor2;
import cc.factorie.model.WeightsSet;
import cc.factorie.optimize.Example;
import cc.factorie.optimize.MultivariateOptimizableObjective;
import cc.factorie.optimize.OptimizableObjectives$;
import cc.factorie.optimize.Trainer;
import cc.factorie.variable.TensorVar;
import java.util.NoSuchElementException;
import scala.Function1;
import scala.Predef$;
import scala.collection.Iterable;
import scala.collection.Iterable$;
import scala.math.Numeric$IntIsIntegral$;
import scala.runtime.BoxesRunTime;

/* compiled from: Regression.scala */
/* loaded from: input_file:cc/factorie/app/regress/LinearRegressionTrainer$.class */
/**
 * Java view of the Scala singleton {@code object LinearRegressionTrainer}: fits a
 * {@link LinearRegressionModel} to a collection of variables and wraps the learned weight tensor in a
 * {@link LinearRegressor}.
 *
 * <p>Methods named {@code xxx$default$N} are the Scala-generated accessors for the default value of
 * the N-th parameter of method {@code xxx}.
 */
public final class LinearRegressionTrainer$ {
    /**
     * The singleton instance, created when the class is initialized. It is built in the field
     * initializer because Java does not allow a {@code static final} field to be assigned from a
     * constructor.
     */
    public static final LinearRegressionTrainer$ MODULE$ = new LinearRegressionTrainer$();

    /**
     * Trains a linear regressor through {@link #trainCustom}. The trainer factory passed on is a
     * {@code LinearRegressionTrainer$$anonfun$2} wrapping a compiler-generated
     * {@code LinearRegressionTrainer$$anon$1} whose {@code variance} is set to {@code 1.0 / d}.
     *
     * @param iterable the training variables; must be non-empty
     * @param function1 maps each training variable to its paired variable (see {@link #trainCustom})
     * @param d regularization strength, converted to a variance of {@code 1.0 / d}.
     *     NOTE(review): presumably an L2 weight; {@code d == 0} yields an infinite variance rather than
     *     an error.
     * @param multivariateOptimizableObjective per-example loss; Scala default is {@link #train$default$4()}
     * @return the trained regressor
     * @throws NoSuchElementException if {@code iterable} is empty
     */
    public <E extends TensorVar, A extends TensorVar> LinearRegressor<E, A> train(Iterable<A> iterable, Function1<A, E> function1, double d, MultivariateOptimizableObjective<Tensor1> multivariateOptimizableObjective) {
        LinearRegressionTrainer$$anon$1 trainerConfig = new LinearRegressionTrainer$$anon$1();
        trainerConfig.variance_$eq(1.0d / d);
        return trainCustom(iterable, function1, new LinearRegressionTrainer$$anonfun$2(trainerConfig), multivariateOptimizableObjective);
    }

    /** Scala default for {@code train}'s 4th parameter: the squared multivariate objective. */
    public <E extends TensorVar, A extends TensorVar> MultivariateOptimizableObjective<Tensor1> train$default$4() {
        return OptimizableObjectives$.MODULE$.squaredMultivariate();
    }

    /**
     * Trains a linear regressor with a caller-supplied trainer factory.
     *
     * <p>Steps:
     * <ol>
     *   <li>Size a new {@link LinearRegressionModel} from the FIRST element {@code x} of
     *       {@code iterable}. The model is constructed as
     *       {@code (flatSize(function1(x)), flatSize(x))}, where a flat size is the product of the
     *       value's dimensions. All other elements are assumed to have the same sizes.</li>
     *   <li>Build a {@link Trainer} by applying {@code function12} to the model's parameters.</li>
     *   <li>Convert every element into an {@link Example} using {@code function1}, the objective and
     *       the model.</li>
     *   <li>Feed the full example set to the trainer repeatedly until it reports convergence. If it
     *       never converges, this loop does not terminate.</li>
     * </ol>
     * NOTE(review): presumably {@code A} is the dependent (response) variable and {@code E} the
     * independent input derived from it via {@code function1}; confirm against Regression.scala.
     *
     * @param iterable the training variables; must be non-empty
     * @param function1 maps each training variable to its paired variable; also stored in the result
     * @param function12 factory creating a trainer for the model's weights
     * @param multivariateOptimizableObjective per-example loss; Scala default is
     *     {@link #trainCustom$default$4()}
     * @return a regressor wrapping {@code function1} and the learned weight tensor
     * @throws NoSuchElementException if {@code iterable} is empty
     */
    public <E extends TensorVar, A extends TensorVar> LinearRegressor<E, A> trainCustom(Iterable<A> iterable, Function1<A, E> function1, Function1<WeightsSet, Trainer> function12, MultivariateOptimizableObjective<Tensor1> multivariateOptimizableObjective) {
        // Same exception type head() would throw, but with a message explaining why an element is needed.
        if (iterable.isEmpty()) {
            throw new NoSuchElementException("trainCustom requires at least one training variable to size the model");
        }
        TensorVar first = (TensorVar) iterable.head();
        TensorVar firstMapped = (TensorVar) function1.apply(first);
        // Argument order (mapped size, then original size) matches the original decompiled call.
        LinearRegressionModel model = new LinearRegressionModel(flatSize(firstMapped), flatSize(first));
        Trainer trainer = (Trainer) function12.apply(model.parameters());
        // Scala's map returns the builder-determined collection type, which Java sees as Object.
        @SuppressWarnings("unchecked")
        Iterable<Example> examples = (Iterable<Example>) iterable.map(new LinearRegressionTrainer$$anonfun$3(function1, multivariateOptimizableObjective, model), Iterable$.MODULE$.canBuildFrom());
        while (!trainer.isConverged()) {
            trainer.processExamples(examples);
        }
        return new LinearRegressor<>(function1, (Tensor2) model.weights().mo121value());
    }

    /** Scala default for {@code trainCustom}'s 4th parameter: the squared multivariate objective. */
    public <E extends TensorVar, A extends TensorVar> MultivariateOptimizableObjective<Tensor1> trainCustom$default$4() {
        return OptimizableObjectives$.MODULE$.squaredMultivariate();
    }

    /** Returns the product of the dimensions of {@code v}'s current value (its flattened element count). */
    private static int flatSize(TensorVar v) {
        return BoxesRunTime.unboxToInt(Predef$.MODULE$.intArrayOps(v.mo121value().dimensions()).product(Numeric$IntIsIntegral$.MODULE$));
    }

    private LinearRegressionTrainer$() {
    }
}
