package cc.factorie.directed;

import cc.factorie.directed.Mixture;
import cc.factorie.directed.MultivariateGaussian;
import cc.factorie.directed.MultivariateGaussianMixture;
import cc.factorie.infer.AssignmentSummary;
import cc.factorie.infer.DiscreteMarginal1;
import cc.factorie.infer.DiscreteSummary1;
import cc.factorie.infer.Infer;
import cc.factorie.infer.Maximize;
import cc.factorie.infer.Summary;
import cc.factorie.la.DenseTensor1;
import cc.factorie.la.Tensor1;
import cc.factorie.model.Model;
import cc.factorie.variable.DiscreteVar;
import cc.factorie.variable.DiscreteVariable;
import cc.factorie.variable.HashMapAssignment;
import cc.factorie.variable.MutableTensorVar;
import cc.factorie.variable.Var;
import scala.Function1;
import scala.None$;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.Iterable;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MultivariateGaussian.scala */
/* loaded from: input_file:cc/factorie/directed/MaximizeMultivariateGaussianMean$.class */
/**
 * Singleton (compiled Scala {@code object}) that re-estimates the mean of multivariate Gaussians by
 * averaging the values of the child factors that use that mean.
 *
 * <p>The {@code MaximizeMultivariateGaussianMean$$anonfun$...} classes instantiated below are
 * compiler-generated closures defined elsewhere; their exact bodies are not visible here.
 */
public final class MaximizeMultivariateGaussianMean$ implements Maximize<Iterable<MutableTensorVar>, DirectedModel> {
    /**
     * The singleton instance. It is initialized here rather than inside the constructor, because a
     * {@code static final} field cannot legally be assigned from a constructor in Java.
     */
    public static final MaximizeMultivariateGaussianMean$ MODULE$ = new MaximizeMultivariateGaussianMean$();

    /** Delegates to the {@code Maximize} trait's default implementation. */
    @Override // cc.factorie.infer.Maximize
    public void maximize(Iterable<MutableTensorVar> iterable, DirectedModel directedModel, Summary summary) {
        Maximize.Cclass.maximize(this, iterable, directedModel, summary);
    }

    /** Default value of {@code maximize}'s third argument, as supplied by the {@code Maximize} trait. */
    @Override // cc.factorie.infer.Maximize
    public Summary maximize$default$3() {
        return Maximize.Cclass.maximize$default$3(this);
    }

    /** Default value of {@code infer}'s third argument, as supplied by the {@code Infer} trait. */
    @Override // cc.factorie.infer.Infer
    public Summary infer$default$3() {
        return Infer.Cclass.infer$default$3(this);
    }

    /**
     * Computes the maximum-likelihood mean for {@code mutableTensorVar} from its extended child factors.
     *
     * <p>NOTE(review): the two filter closures are not visible here. They are presumably "this factor's
     * mean is {@code mutableTensorVar}" and "index of {@code mutableTensorVar} among the mixture means".
     *
     * @param mutableTensorVar the mean variable to re-estimate
     * @param directedModel the model supplying {@code extendedChildFactors}
     * @param summary optional marginals for mixture gates; may be {@code null}
     * @return the new mean, or {@code None} if it could not be computed
     */
    public Option<Tensor1> maxMean(MutableTensorVar mutableTensorVar, DirectedModel directedModel, Summary summary) {
        return getMeanFromFactors(directedModel.extendedChildFactors(mutableTensorVar), new MaximizeMultivariateGaussianMean$$anonfun$maxMean$1(mutableTensorVar), new MaximizeMultivariateGaussianMean$$anonfun$maxMean$2(mutableTensorVar), summary);
    }

    /**
     * Computes {@link #maxMean} and, if a mean is available, passes it to a compiler-generated closure.
     *
     * <p>NOTE(review): that closure presumably assigns the mean to {@code mutableTensorVar}; confirm.
     */
    public void apply(MutableTensorVar mutableTensorVar, DirectedModel directedModel, DiscreteSummary1<DiscreteVar> discreteSummary1) {
        maxMean(mutableTensorVar, directedModel, discreteSummary1).foreach(new MaximizeMultivariateGaussianMean$$anonfun$apply$1(mutableTensorVar));
    }

    /** Default for {@code apply}'s third argument: no summary ({@code null}). */
    public DiscreteSummary1<DiscreteVar> apply$default$3() {
        return null;
    }

    /**
     * Runs the closure {@code $anonfun$infer$1} on each variable, all sharing one (initially empty)
     * {@link HashMapAssignment}, and wraps that assignment in an {@link AssignmentSummary}.
     *
     * <p>NOTE(review): the closure presumably records each variable's {@code maxMean} result in the
     * assignment; confirm.
     */
    public AssignmentSummary infer(Iterable<MutableTensorVar> iterable, DirectedModel directedModel, Summary summary) {
        HashMapAssignment hashMapAssignment = new HashMapAssignment((Seq<Var>) Nil$.MODULE$);
        iterable.foreach(new MaximizeMultivariateGaussianMean$$anonfun$infer$1(directedModel, summary, hashMapAssignment));
        return new AssignmentSummary(hashMapAssignment);
    }

    /**
     * Averages the values of the given factors, mirroring the original Scala pattern match (first
     * matching case wins).
     *
     * <ul>
     *   <li>{@link MultivariateGaussian.Factor} accepted by {@code function1}: adds its value with
     *       weight 1.</li>
     *   <li>{@link MultivariateGaussianMixture.Factor} for which {@code function12} returns an index
     *       other than -1:
     *     <ul>
     *       <li>With a gate marginal in {@code summary}: adds its value weighted by that component's
     *           gate probability.</li>
     *       <li>Otherwise: adds its value with weight 1, but only if the gate currently selects that
     *           component.</li>
     *     </ul>
     *   </li>
     *   <li>{@link Mixture.Factor}: ignored.</li>
     *   <li>Anything else, including factors rejected by the filters: prints a message and aborts
     *       with {@code None}.</li>
     * </ul>
     *
     * @param iterable the factors to average over
     * @param function1 filter for plain Gaussian factors (returns a boxed boolean)
     * @param function12 maps a mixture factor to the relevant component index, or -1 (boxed int)
     * @param summary optional marginals for mixture gates; may be {@code null}
     * @return the weighted mean, or {@code None} if an unsupported factor is found, no factor
     *     contributed, or the total weight is zero
     */
    public Option<Tensor1> getMeanFromFactors(Iterable<DirectedFactor> iterable, Function1<MultivariateGaussian.Factor, Object> function1, Function1<MultivariateGaussianMixture.Factor, Object> function12, Summary summary) {
        Iterator it = iterable.iterator();
        DenseTensor1 mean = null;
        double totalWeight = 0.0d;
        while (it.hasNext()) {
            DirectedFactor directedFactor = (DirectedFactor) it.next();

            if (directedFactor instanceof MultivariateGaussian.Factor) {
                MultivariateGaussian.Factor gaussian = (MultivariateGaussian.Factor) directedFactor;
                if (BoxesRunTime.unboxToBoolean(function1.apply(gaussian))) {
                    MutableTensorVar observed = gaussian.mo1626_1();
                    if (mean == null) {
                        mean = new DenseTensor1(((Tensor1) observed.mo139value()).length(), 0.0d);
                    }
                    mean.$plus$eq(observed.mo139value());
                    totalWeight += 1.0d;
                    // Handled: do not fall through to the later cases.
                    continue;
                }
            }

            if (directedFactor instanceof MultivariateGaussianMixture.Factor) {
                MultivariateGaussianMixture.Factor mixtureFactor = (MultivariateGaussianMixture.Factor) directedFactor;
                // Evaluate the filter once; -1 means "this factor does not involve our mean".
                int componentIndex = BoxesRunTime.unboxToInt(function12.apply(mixtureFactor));
                if (componentIndex != -1) {
                    MutableTensorVar observed = mixtureFactor.mo1626_1();
                    DiscreteVariable gate = mixtureFactor._4();
                    if (mean == null) {
                        mean = new DenseTensor1(((Tensor1) observed.mo139value()).length(), 0.0d);
                    }
                    DiscreteMarginal1 gateMarginal = summary == null ? null : (DiscreteMarginal1) summary.marginal(gate);
                    if (gateMarginal != null) {
                        // Soft assignment: weight by the gate's marginal probability of this component.
                        double p = gateMarginal.proportions().mo364apply(componentIndex);
                        Predef$.MODULE$.assert(p == p); // rejects NaN
                        Predef$.MODULE$.assert(p >= 0.0d);
                        Predef$.MODULE$.assert(p <= 1.0d);
                        mean.$plus$eq(observed.mo139value(), p);
                        totalWeight += p;
                    } else if (gate.intValue() == componentIndex) {
                        // Hard assignment: count only observations currently gated to this component.
                        mean.$plus$eq(observed.mo139value());
                        totalWeight += 1.0d;
                    }
                    // Handled: do not fall through to the later cases.
                    continue;
                }
            }

            if (directedFactor instanceof Mixture.Factor) {
                // Mixture bookkeeping factor: contributes nothing to the mean.
                continue;
            }

            Predef$.MODULE$.println(new StringBuilder().append("MaximizeMultivariateGaussianMean can't handle factor ").append(directedFactor.getClass().getName()).append("=").append(directedFactor).toString());
            return None$.MODULE$;
        }
        if (mean == null || totalWeight == 0.0d) {
            // No contributing observations: the mean is undefined. Avoid an NPE on a null tensor and
            // a NaN result from dividing by zero.
            return None$.MODULE$;
        }
        mean.$div$eq(totalWeight);
        return new Some(mean);
    }

    /** Erased bridge for {@link Infer#infer}; forwards to the typed overload. */
    @Override // cc.factorie.infer.Infer
    public /* bridge */ /* synthetic */ Summary infer(Iterable iterable, Model model, Summary summary) {
        return infer((Iterable<MutableTensorVar>) iterable, (DirectedModel) model, summary);
    }

    /** Runs the trait initializers; only invoked once, to create {@link #MODULE$}. */
    private MaximizeMultivariateGaussianMean$() {
        Infer.Cclass.$init$(this);
        Maximize.Cclass.$init$(this);
    }
}
