package cc.factorie.directed;

import cc.factorie.directed.DirectedFamily3;
import cc.factorie.directed.MultivariateGaussian;
import cc.factorie.la.DenseTensor1;
import cc.factorie.la.DenseTensor2;
import cc.factorie.la.Tensor1;
import cc.factorie.la.Tensor2;
import cc.factorie.variable.MutableTensorVar;
import org.jblas.Decompose;
import org.jblas.DoubleMatrix;
import org.jblas.Solve;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.reflect.ClassTag$;
import scala.runtime.ScalaRunTime$;
import scala.util.Random;

/* compiled from: MultivariateGaussian.scala */
/* loaded from: input_file:cc/factorie/directed/MultivariateGaussian$.class */
/**
 * Java form of the Scala object {@code MultivariateGaussian} (see the "compiled from" header).
 *
 * <p>It is a {@link DirectedFamily3} over three {@link MutableTensorVar}s (presumably child, mean,
 * covariance — TODO confirm against {@code MultivariateGaussian.Factor}). It also provides
 * multivariate-normal density and sampling helpers and jblas-backed matrix utilities.
 */
public final class MultivariateGaussian$ implements DirectedFamily3<MutableTensorVar, MutableTensorVar, MutableTensorVar> {
    // NOTE(review): this mirrors the Scala singleton encoding from the bytecode: the private
    // constructor assigns MODULE$, and the static initializer runs the constructor once.
    // Assigning an already-initialized static final field is not legal Java source, so this
    // declaration must be adjusted before the file can be compiled as plain Java.
    public static final MultivariateGaussian$ MODULE$ = null;

    static {
        new MultivariateGaussian$();
    }

    /** Delegates to the {@link DirectedFamily3} trait implementation of {@code apply}. */
    @Override // cc.factorie.directed.DirectedFamily3
    public Function1 apply(MutableTensorVar mutableTensorVar, MutableTensorVar mutableTensorVar2) {
        return DirectedFamily3.Cclass.apply(this, mutableTensorVar, mutableTensorVar2);
    }

    /**
     * Computes the log-density of {@code tensor1} under a multivariate Gaussian.
     *
     * <p>The mean is {@code tensor12} and the covariance is {@code tensor2}:
     * {@code -0.5 * (k*log(2*pi) + log(det(tensor2)) + d . (inverse(tensor2) * d))},
     * with {@code d = tensor1 - tensor12} and {@code k = tensor12.length()}.
     *
     * <p>A determinant of zero gives {@code log = -Infinity}. A negative determinant gives NaN,
     * which means {@code tensor2} is not a valid covariance matrix.
     *
     * @param tensor1 the point being evaluated
     * @param tensor12 the mean vector
     * @param tensor2 the covariance matrix (square, dimension {@code k})
     * @return the log probability density
     */
    public double logpr(Tensor1 tensor1, Tensor1 tensor12, Tensor2 tensor2) {
        int length = tensor12.length();
        Tensor1 $minus = tensor1.$minus(tensor12);
        // 2 * Math.PI is exactly 6.283185307179586, the constant used by the compiled code.
        return (-0.5d) * ((length * scala.math.package$.MODULE$.log(2.0d * Math.PI)) + scala.math.package$.MODULE$.log(determinant(tensor2)) + invert(tensor2).$times($minus).mo1562dot($minus));
    }

    /**
     * Returns the Gaussian density, computed as {@code exp(logpr(...))}.
     *
     * @param tensor1 the point being evaluated
     * @param tensor12 the mean vector
     * @param tensor2 the covariance matrix
     * @return the probability density
     */
    public double pr(Tensor1 tensor1, Tensor1 tensor12, Tensor2 tensor2) {
        return scala.math.package$.MODULE$.exp(logpr(tensor1, tensor12, tensor2));
    }

    /**
     * Draws a sample by delegating to {@link #nextGaussian(Tensor1, Tensor2, Random)}.
     *
     * @param tensor1 the mean vector
     * @param tensor2 the covariance matrix
     * @param random the randomness source
     * @return a sampled vector
     */
    public Tensor1 sampledValue(Tensor1 tensor1, Tensor2 tensor2, Random random) {
        return nextGaussian(tensor1, tensor2, random);
    }

    /** Creates a new {@code MultivariateGaussian.Factor} over the three given variables. */
    @Override // cc.factorie.directed.DirectedFamily3
    public MultivariateGaussian.Factor newFactor(MutableTensorVar mutableTensorVar, MutableTensorVar mutableTensorVar2, MutableTensorVar mutableTensorVar3) {
        return new MultivariateGaussian.Factor(mutableTensorVar, mutableTensorVar2, mutableTensorVar3);
    }

    /**
     * Samples from a multivariate Gaussian with mean {@code tensor1} and covariance {@code tensor2}.
     *
     * <p>The method fills a vector {@code z} of length {@code tensor1.length()} using the compiled
     * closure {@code MultivariateGaussian$$anonfun$1} (presumably {@code random.nextGaussian()}
     * draws — TODO confirm). It then returns {@code z * cholesky(tensor2) + tensor1}. jblas
     * documents its Cholesky factor as upper-triangular {@code U} with {@code A = U' * U}, so the
     * row-vector product {@code z * U} has covariance {@code A}.
     *
     * @param tensor1 the mean vector
     * @param tensor2 the covariance matrix (must be symmetric positive definite for Cholesky)
     * @param random the randomness source
     * @return a sampled vector
     */
    public Tensor1 nextGaussian(Tensor1 tensor1, Tensor2 tensor2, Random random) {
        return new DenseTensor1((double[]) Array$.MODULE$.fill(tensor1.length(), new MultivariateGaussian$$anonfun$1(random), ClassTag$.MODULE$.Double())).$times(cholesky(tensor2)).$plus(tensor1);
    }

    /**
     * Converts a jblas matrix into a dense factorie tensor, row by row.
     *
     * @param doubleMatrix the source matrix
     * @return a {@link DenseTensor2} built from {@code doubleMatrix.toArray2()}
     */
    public DenseTensor2 matrix2Tensor(DoubleMatrix doubleMatrix) {
        return new DenseTensor2(doubleMatrix.toArray2());
    }

    /**
     * Converts a factorie matrix into a jblas {@link DoubleMatrix}.
     *
     * <p>This assumes the flat storage from {@code asArray()} is row-major: element {@code (i, j)}
     * sits at {@code i * dim2 + j}. That is the layout the previous chunking also assumed. Each row
     * therefore holds {@code dim2} values. The previous version split the array into chunks of
     * {@code dim1}, which was only correct for square matrices.
     *
     * @param tensor2 the source matrix ({@code dim1} rows by {@code dim2} columns)
     * @return an equivalent {@code dim1 x dim2} jblas matrix
     * @throws IllegalArgumentException if the flat storage holds fewer than {@code dim1 * dim2} values
     */
    public DoubleMatrix tensor2Matrix(Tensor2 tensor2) {
        int rows = tensor2.dim1();
        int cols = tensor2.dim2();
        double[] flat = tensor2.asArray();
        if (flat.length < rows * cols) {
            throw new IllegalArgumentException("Tensor2 storage has " + flat.length
                    + " values but dimensions are " + rows + "x" + cols);
        }
        double[][] data = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(flat, i * cols, data[i], 0, cols);
        }
        return new DoubleMatrix(data);
    }

    /**
     * Inverts a square matrix by solving {@code tensor2 * X = I} with jblas.
     *
     * @param tensor2 a square, non-singular matrix
     * @return the inverse as a dense tensor
     */
    public DenseTensor2 invert(Tensor2 tensor2) {
        return matrix2Tensor(Solve.solve(tensor2Matrix(tensor2), DoubleMatrix.eye(tensor2.dim1())));
    }

    /**
     * Returns the jblas Cholesky factor of {@code tensor2}.
     *
     * <p>jblas documents this factor as upper-triangular {@code U} with {@code A = U' * U}.
     *
     * @param tensor2 a symmetric positive-definite matrix
     * @return the Cholesky factor as a dense tensor
     */
    public DenseTensor2 cholesky(Tensor2 tensor2) {
        return matrix2Tensor(Decompose.cholesky(tensor2Matrix(tensor2)));
    }

    /**
     * Computes the determinant of a square matrix from its pivoted LU decomposition.
     *
     * <p>jblas factors {@code A = P * L * U}, where {@code L} has a unit diagonal, so
     * {@code det(A) = det(P) * prod(diag(U))}. The permutation sign must be included. Without it,
     * a row swap during pivoting flips the sign. For example, the covariance-like matrix
     * {@code [[1,2],[2,5]]} would come out as -1 instead of 1, and {@link #logpr} would then
     * return NaN.
     *
     * @param tensor2 a square matrix
     * @return its determinant
     */
    public double determinant(Tensor2 tensor2) {
        Decompose.LUDecomposition<DoubleMatrix> lu = Decompose.lu(tensor2Matrix(tensor2));
        return permutationSign(lu.p) * lu.u.diag().prod();
    }

    /**
     * Returns the sign (+1 or -1) of a permutation matrix.
     *
     * <p>Each row's single non-zero column defines the permutation. A cycle of even length is an
     * odd number of transpositions, so it flips the sign.
     */
    private static double permutationSign(DoubleMatrix p) {
        int n = p.rows;
        int[] target = new int[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (p.get(i, j) != 0.0d) {
                    target[i] = j;
                    break;
                }
            }
        }
        boolean[] visited = new boolean[n];
        double sign = 1.0d;
        for (int start = 0; start < n; start++) {
            if (visited[start]) {
                continue;
            }
            int cycleLength = 0;
            for (int k = start; !visited[k]; k = target[k]) {
                visited[k] = true;
                cycleLength++;
            }
            if (cycleLength % 2 == 0) {
                sign = -sign;
            }
        }
        return sign;
    }

    /** Singleton constructor: publishes MODULE$, then runs the trait initializer. */
    private MultivariateGaussian$() {
        MODULE$ = this;
        DirectedFamily3.Cclass.$init$(this);
    }
}
