/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math4.distribution.fitting;

import io.virtdata.shaded.oac.statistics.correlation.Covariance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math4.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math4.distribution.MultivariateNormalDistribution;
import org.apache.commons.math4.exception.ConvergenceException;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.exception.util.Localizable;
import org.apache.commons.math4.exception.util.LocalizedFormats;
import org.apache.commons.math4.linear.Array2DRowRealMatrix;
import org.apache.commons.math4.linear.RealMatrix;
import org.apache.commons.math4.linear.SingularMatrixException;
import org.apache.commons.math4.util.FastMath;
import org.apache.commons.math4.util.MathArrays;
import org.apache.commons.math4.util.Pair;

/**
 * Fits a mixture of multivariate normal distributions to a fixed data set using
 * the expectation-maximization (EM) algorithm.
 *
 * <p>A starting mixture can be obtained from {@link #estimate(double[][], int)} and
 * refined with {@link #fit(MixtureMultivariateNormalDistribution)}; the result is
 * then available from {@link #getFittedModel()} and {@link #getLogLikelihood()}.</p>
 *
 * <p>Instances are not thread-safe: {@code fit} overwrites the fitted-model and
 * log-likelihood fields without any synchronization.</p>
 */
public class MultivariateNormalMixtureExpectationMaximization {
    /** Maximum number of EM iterations used by {@link #fit(MixtureMultivariateNormalDistribution)}. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /** Convergence threshold (change in average log-likelihood) used by the one-argument {@code fit}. */
    private static final double DEFAULT_THRESHOLD = 1.0E-5;
    /** Private copy of the observations, one row per observation. */
    private final double[][] data;
    /** Model produced by the most recent call to {@code fit}; {@code null} before the first call. */
    private MixtureMultivariateNormalDistribution fittedModel;
    /** Average log-density of the data computed during the last EM iteration; 0 before any fit. */
    private double logLikelihood = 0.0;

    /**
     * Creates a fitter for the given observations.
     *
     * @param data observations, one row per sample. All rows must have the same
     *        length, which must be at least 2. The rows are copied.
     * @throws NotStrictlyPositiveException if {@code data} has no rows
     * @throws DimensionMismatchException if the rows do not all have the same length
     * @throws NumberIsTooSmallException if the rows have fewer than two columns
     */
    public MultivariateNormalMixtureExpectationMaximization(double[][] data)
            throws NotStrictlyPositiveException, DimensionMismatchException, NumberIsTooSmallException {
        if (data.length < 1) {
            throw new NotStrictlyPositiveException(data.length);
        }
        final int numCols = data[0].length;
        // Single-column (univariate) data is rejected.
        if (numCols < 2) {
            throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL, numCols, 2, true);
        }
        // Rows are filled with copies below, so no per-row storage is pre-allocated.
        this.data = new double[data.length][];
        for (int i = 0; i < data.length; ++i) {
            if (data[i].length != numCols) {
                throw new DimensionMismatchException(data[i].length, numCols);
            }
            this.data[i] = MathArrays.copyOf(data[i], numCols);
        }
    }

    /**
     * Fits the mixture to the data by running EM iterations, starting from
     * {@code initialMixture}. Iteration stops once the change in average
     * log-likelihood between two consecutive iterations is at most {@code threshold}.
     *
     * <p>Each iteration computes, for every observation, the responsibility of each
     * component (E-step). It then re-estimates the component weights, means and
     * covariance matrices from those responsibilities (M-step). The first
     * iteration is compared against negative infinity, so convergence can only be
     * detected from the second iteration on.</p>
     *
     * <p>After the call, {@link #getFittedModel()} returns the mixture built by the
     * last M-step. {@link #getLogLikelihood()} returns the average log-density
     * computed in the last E-step, i.e. under the mixture that was current at the
     * start of that iteration.</p>
     *
     * @param initialMixture starting model; its dimension must equal the number of data columns
     * @param maxIterations maximum number of EM iterations to run (at least 1)
     * @param threshold strictly positive convergence threshold on the change in average log-likelihood
     * @throws NotStrictlyPositiveException if {@code maxIterations < 1}, or if {@code threshold}
     *         is not strictly positive (NaN included)
     * @throws DimensionMismatchException if the mixture dimension differs from the number of data columns
     * @throws SingularMatrixException propagated while building an updated mixture —
     *         presumably when a re-estimated covariance matrix is singular
     * @throws ConvergenceException if the threshold is not met within {@code maxIterations} iterations.
     *         The fitted model then holds the last (unconverged) mixture.
     */
    public void fit(final MixtureMultivariateNormalDistribution initialMixture,
                    final int maxIterations,
                    final double threshold)
            throws SingularMatrixException, NotStrictlyPositiveException, DimensionMismatchException {
        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }
        // Negated comparison so that NaN is rejected as well. A NaN threshold would
        // otherwise skip the loop entirely and report the initial mixture as converged.
        if (!(threshold > 0.0)) {
            throw new NotStrictlyPositiveException(threshold);
        }

        final int numRows = data.length;
        final int numCols = data[0].length;
        final List<Pair<Double, MultivariateNormalDistribution>> initialComponents =
            initialMixture.getComponents();
        final int numComponents = initialComponents.size();
        final int numMeanColumns = initialComponents.get(0).getSecond().getMeans().length;
        if (numMeanColumns != numCols) {
            throw new DimensionMismatchException(numMeanColumns, numCols);
        }

        int numIterations = 0;
        double previousLogLikelihood = 0.0;
        logLikelihood = Double.NEGATIVE_INFINITY;
        fittedModel = new MixtureMultivariateNormalDistribution(initialComponents);

        // Strict '<' so that at most maxIterations iterations are performed.
        while (numIterations++ < maxIterations
                && FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            previousLogLikelihood = logLikelihood;

            // Unpack the current mixture into parallel weight / distribution arrays.
            final List<Pair<Double, MultivariateNormalDistribution>> components = fittedModel.getComponents();
            final double[] weights = new double[numComponents];
            final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[numComponents];
            for (int j = 0; j < numComponents; ++j) {
                weights[j] = components.get(j).getFirst();
                mvns[j] = components.get(j).getSecond();
            }

            // E-step: gamma[i][j] is the responsibility of component j for row i.
            // Per-component sums of gamma and of gamma * row are accumulated for the M-step.
            final double[][] gamma = new double[numRows][numComponents];
            final double[] gammaSums = new double[numComponents];
            final double[][] gammaDataProdSums = new double[numComponents][numCols];
            double sumLogLikelihood = 0.0;
            for (int i = 0; i < numRows; ++i) {
                final double rowDensity = fittedModel.density(data[i]);
                sumLogLikelihood += FastMath.log(rowDensity);
                for (int j = 0; j < numComponents; ++j) {
                    gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
                    gammaSums[j] += gamma[i][j];
                    for (int col = 0; col < numCols; ++col) {
                        gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
                    }
                }
            }
            logLikelihood = sumLogLikelihood / numRows;

            // M-step, weights and means.
            final double[] newWeights = new double[numComponents];
            final double[][] newMeans = new double[numComponents][numCols];
            for (int j = 0; j < numComponents; ++j) {
                newWeights[j] = gammaSums[j] / numRows;
                for (int col = 0; col < numCols; ++col) {
                    newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
                }
            }

            // M-step, covariances: responsibility-weighted sum of outer products of the
            // deviations from the new means, divided by the component's total
            // responsibility. Accumulated on plain arrays (same operation order as a
            // matrix-based outer product) to avoid allocating a d-by-d matrix per
            // (row, component) pair.
            final double[][][] newCovMats = new double[numComponents][numCols][numCols];
            for (int i = 0; i < numRows; ++i) {
                for (int j = 0; j < numComponents; ++j) {
                    final double[] dev = MathArrays.ebeSubtract(data[i], newMeans[j]);
                    final double g = gamma[i][j];
                    final double[][] cov = newCovMats[j];
                    for (int r = 0; r < numCols; ++r) {
                        for (int c = 0; c < numCols; ++c) {
                            cov[r][c] += dev[r] * dev[c] * g;
                        }
                    }
                }
            }
            for (int j = 0; j < numComponents; ++j) {
                final double scale = 1.0 / gammaSums[j];
                for (final double[] covRow : newCovMats[j]) {
                    for (int c = 0; c < numCols; ++c) {
                        covRow[c] *= scale;
                    }
                }
            }

            fittedModel = new MixtureMultivariateNormalDistribution(newWeights, newMeans, newCovMats);
        }

        // The loop only exits with an unmet threshold when the iteration budget ran out.
        if (FastMath.abs(previousLogLikelihood - logLikelihood) > threshold) {
            throw new ConvergenceException();
        }
    }

    /**
     * Fits the mixture using {@code DEFAULT_MAX_ITERATIONS} (1000) iterations and
     * {@code DEFAULT_THRESHOLD} (1e-5) as the convergence threshold.
     *
     * @param initialMixture starting model; its dimension must equal the number of data columns
     * @throws SingularMatrixException see {@link #fit(MixtureMultivariateNormalDistribution, int, double)}
     * @throws NotStrictlyPositiveException see {@link #fit(MixtureMultivariateNormalDistribution, int, double)}
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture)
            throws SingularMatrixException, NotStrictlyPositiveException {
        fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Builds a starting mixture for {@code fit} directly from the data, without iterating.
     *
     * <p>The rows are sorted by the mean of their entries and split into
     * {@code numComponents} contiguous bins of (near-)equal size. Each bin yields one
     * component with weight {@code 1 / numComponents}. Its mean vector is the bin's
     * column means, and its covariance is the matrix computed by {@code Covariance}
     * over the bin's rows.</p>
     *
     * <p>NOTE(review): when {@code numComponents > data.length / 2}, some bins hold a
     * single row, and {@code Covariance} presumably rejects that — confirm.</p>
     *
     * @param data observations, one row per sample; all rows must have the same length
     * @param numComponents number of mixture components to create (at least 2, at most the number of rows)
     * @return the initial mixture
     * @throws NotStrictlyPositiveException if {@code data} has fewer than two rows
     * @throws NumberIsTooSmallException if {@code numComponents < 2}
     * @throws NumberIsTooLargeException if {@code numComponents > data.length}
     * @throws DimensionMismatchException if the rows do not all have the same length
     */
    public static MixtureMultivariateNormalDistribution estimate(final double[][] data, final int numComponents)
            throws NotStrictlyPositiveException, DimensionMismatchException {
        if (data.length < 2) {
            throw new NotStrictlyPositiveException(data.length);
        }
        if (numComponents < 2) {
            throw new NumberIsTooSmallException(numComponents, 2, true);
        }
        if (numComponents > data.length) {
            throw new NumberIsTooLargeException(numComponents, data.length, true);
        }

        final int numRows = data.length;
        final int numCols = data[0].length;

        // Wrap the rows so they can be ordered by their mean value. A ragged row would
        // otherwise fail deep inside the copy loop, or be silently truncated.
        final DataRow[] sortedData = new DataRow[numRows];
        for (int i = 0; i < numRows; ++i) {
            if (data[i].length != numCols) {
                throw new DimensionMismatchException(data[i].length, numCols);
            }
            sortedData[i] = new DataRow(data[i]);
        }
        Arrays.sort(sortedData);

        final double weight = 1.0 / numComponents;
        final List<Pair<Double, MultivariateNormalDistribution>> components = new ArrayList<>(numComponents);

        for (int binIndex = 0; binIndex < numComponents; ++binIndex) {
            // Bin bounds are computed in long: binIndex * numRows can exceed Integer.MAX_VALUE.
            final int minIndex = (int) ((long) binIndex * numRows / numComponents);
            final int maxIndex = (int) ((long) (binIndex + 1) * numRows / numComponents);
            final int numBinRows = maxIndex - minIndex;

            final double[][] binData = new double[numBinRows][numCols];
            final double[] columnMeans = new double[numCols];
            for (int i = minIndex, binRow = 0; i < maxIndex; ++i, ++binRow) {
                final double[] row = sortedData[i].getRow();
                for (int col = 0; col < numCols; ++col) {
                    final double value = row[col];
                    columnMeans[col] += value;
                    binData[binRow][col] = value;
                }
            }
            MathArrays.scaleInPlace(1.0 / numBinRows, columnMeans);

            final double[][] covMat = new Covariance(binData).getCovarianceMatrix().getData();
            final MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(columnMeans, covMat);
            components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
        }

        return new MixtureMultivariateNormalDistribution(components);
    }

    /**
     * Returns the average log-density of the data computed in the last EM iteration.
     * That value is evaluated under the mixture current at the start of that
     * iteration, not the one returned by {@link #getFittedModel()}.
     *
     * @return the average log-likelihood; 0 if {@code fit} has never been called
     */
    public double getLogLikelihood() {
        return logLikelihood;
    }

    /**
     * Returns a new mixture built from the components of the most recently fitted model.
     * If the last {@code fit} threw {@link ConvergenceException}, this is the last,
     * unconverged mixture.
     *
     * @return the fitted mixture
     * @throws NullPointerException if {@code fit} has never been called
     */
    public MixtureMultivariateNormalDistribution getFittedModel() {
        return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
    }

    /**
     * A data row paired with the arithmetic mean of its entries; rows are ordered by that mean.
     *
     * <p>Note that {@link #compareTo} compares means, while {@link #equals} compares
     * row contents, so the natural ordering is inconsistent with equals. It is only
     * used to sort rows in {@code estimate}.</p>
     */
    private static class DataRow implements Comparable<DataRow> {
        /** The wrapped row (not copied). */
        private final double[] row;
        /** Arithmetic mean of the row's entries. */
        private final double mean;

        DataRow(final double[] data) {
            row = data;
            double sum = 0.0;
            for (final double value : data) {
                sum += value;
            }
            mean = sum / data.length;
        }

        @Override
        public int compareTo(final DataRow other) {
            return Double.compare(mean, other.mean);
        }

        /**
         * Compares row contents exactly (bitwise per element). This matches
         * {@link Arrays#hashCode(double[])} as used by {@link #hashCode()}, so the
         * equals/hashCode contract holds.
         */
        @Override
        public boolean equals(final Object other) {
            if (this == other) {
                return true;
            }
            if (other instanceof DataRow) {
                return Arrays.equals(row, ((DataRow) other).row);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(row);
        }

        /** @return the wrapped row (the internal array, not a copy) */
        public double[] getRow() {
            return row;
        }
    }
}

