package dragon.ir.topicmodel;

import dragon.matrix.IntSparseMatrix;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:dragon/ir/topicmodel/CrossMixtureModel.class */
/**
 * Cross-collection mixture topic model.
 *
 * <p>Each of {@code themeNum} themes is represented in every collection by a blend of a
 * distribution shared by all collections (weight {@code comCoefficient}) and a
 * collection-specific distribution (weight {@code 1 - comCoefficient}). A fixed background
 * distribution, mixed in with weight {@code bkgCoefficient}, absorbs words that belong to no
 * theme. For a term {@code w} in document {@code d} of collection {@code c} the model is
 *
 * <pre>
 *   p(w) = bkgCoefficient * pB(w)
 *        + (1 - bkgCoefficient) * sum_k pi[c][k][d]
 *              * (comCoefficient * common[k][w] + (1 - comCoefficient) * specific[c][k][w])
 * </pre>
 *
 * <p>{@link #estimateModel()} re-estimates the common models, the collection-specific models
 * and the per-document theme weights iteratively (EM style); the background model is never
 * updated.
 */
public class CrossMixtureModel extends AbstractModel {
    /** One document-by-term matrix per collection (rows = documents, columns = term ids). */
    protected IntSparseMatrix[] arrTopicReader;
    /** Background term distribution, stored pre-multiplied by {@link #bkgCoefficient}. */
    protected double[] bkgModel;
    /** Mixing weight of the background distribution. */
    protected double bkgCoefficient;
    /** Weight of the shared (common) component inside every theme. */
    protected double comCoefficient;
    protected int themeNum;
    protected int collectionNum;
    /** Largest column count over all collections; length of every term distribution. */
    protected int maxTermNum;
    /** Largest row count over all collections; length of every document-weight vector. */
    protected int maxDocNum;
    /** Indexed [collection][theme][document]; {@code null} until estimateModel() has run. */
    private double[][][] arrDocWeight;
    /** Indexed [collection][theme][term]; {@code null} until estimateModel() has run. */
    private double[][][] arrProb;
    /** Indexed [theme][term]; {@code null} until estimateModel() has run. */
    private double[][] arrCommonProb;

    /**
     * Creates a model over several collections.
     *
     * @param topicReaders one sparse document-by-term matrix per collection; its integer
     *     scores are used as term weights (e.g. frequencies)
     * @param themeNum number of themes to estimate
     * @param bkgModel background term distribution; it is expected to have an entry for every
     *     term id that occurs in {@code topicReaders}. A scaled copy is stored, so the caller's
     *     array is not modified
     * @param bkgCoefficient mixing weight of the background distribution
     * @param comCoefficient weight of the common component inside each theme
     * @throws IllegalArgumentException if {@code topicReaders} is null or empty
     */
    public CrossMixtureModel(IntSparseMatrix[] topicReaders, int themeNum, double[] bkgModel,
            double bkgCoefficient, double comCoefficient) {
        if (topicReaders == null || topicReaders.length == 0) {
            throw new IllegalArgumentException("at least one collection is required");
        }
        this.arrTopicReader = topicReaders;
        this.themeNum = themeNum;
        this.collectionNum = topicReaders.length;
        this.bkgCoefficient = bkgCoefficient;
        this.comCoefficient = comCoefficient;
        // Store lambda_B * pB(w) once; the E-step only ever needs the scaled value.
        this.bkgModel = new double[bkgModel.length];
        for (int term = 0; term < bkgModel.length; term++) {
            this.bkgModel[term] = bkgModel[term] * bkgCoefficient;
        }
        // All result arrays are sized to the largest collection so they can share one shape.
        this.maxTermNum = topicReaders[0].columns();
        this.maxDocNum = topicReaders[0].rows();
        for (int c = 1; c < topicReaders.length; c++) {
            this.maxTermNum = Math.max(this.maxTermNum, topicReaders[c].columns());
            this.maxDocNum = Math.max(this.maxDocNum, topicReaders[c].rows());
        }
    }

    /**
     * Returns the collection-specific theme models, indexed [collection][theme][term].
     * This is the internal array (not a copy), or {@code null} before estimation.
     */
    public double[][][] getModels() {
        return this.arrProb;
    }

    /**
     * Returns the theme models shared by all collections, indexed [theme][term].
     * This is the internal array (not a copy), or {@code null} before estimation.
     */
    public double[][] getCommonModels() {
        return this.arrCommonProb;
    }

    /**
     * Returns the per-document theme weights, indexed [collection][theme][document].
     * This is the internal array (not a copy), or {@code null} before estimation.
     */
    public double[][][] getDocMemberships() {
        return this.arrDocWeight;
    }

    /**
     * Fits the model with {@code iterations} rounds of expectation/maximization.
     *
     * <p>Allocates fresh result arrays, which discards any previous results, and seeds them
     * through {@link #initialize}. Each round does the following:
     * <ol>
     *   <li>Accumulates expected term counts per theme after removing the background share,
     *       split between the common and the collection-specific component.</li>
     *   <li>Re-estimates each document's theme weights in place.</li>
     *   <li>Renormalizes the accumulated counts into new common and collection-specific
     *       distributions.</li>
     * </ol>
     * A distribution whose counts sum to zero is set to all zeros rather than NaN.
     *
     * @return always {@code true}
     */
    public boolean estimateModel() {
        arrProb = new double[collectionNum][themeNum][maxTermNum];
        arrCommonProb = new double[themeNum][maxTermNum];
        arrDocWeight = new double[collectionNum][themeNum][maxDocNum];
        // Expected (fractional) term counts gathered during one pass over the data.
        double[][][] specificCounts = new double[collectionNum][themeNum][maxTermNum];
        double[][] commonCounts = new double[themeNum][maxTermNum];
        // Scratch buffer for one document's theme totals; reused for every document.
        double[] docThemeCounts = new double[themeNum];

        initialize(maxTermNum, collectionNum, themeNum, maxDocNum,
                arrCommonProb, arrProb, arrDocWeight);
        printStatus("Estimating the coefficients of simple mixture model...");
        for (int iter = 0; iter < iterations; iter++) {
            printStatus("Iteration #" + (iter + 1));
            for (double[] row : commonCounts) {
                Arrays.fill(row, 0.0d);
            }
            for (double[][] collectionCounts : specificCounts) {
                for (double[] row : collectionCounts) {
                    Arrays.fill(row, 0.0d);
                }
            }
            for (int c = 0; c < collectionNum; c++) {
                accumulateCollection(c, specificCounts[c], commonCounts, docThemeCounts);
            }
            // M-step: distributions are replaced only after every collection has been scanned,
            // so the E-step above always sees the previous iteration's models.
            for (int k = 0; k < themeNum; k++) {
                normalizeInto(commonCounts[k], arrCommonProb[k]);
                for (int c = 0; c < collectionNum; c++) {
                    normalizeInto(specificCounts[c][k], arrProb[c][k]);
                }
            }
        }
        printStatus("");
        return true;
    }

    /**
     * Runs the E-step for one collection.
     *
     * <p>For every document it splits each term's weight over the themes, adds the
     * non-background part to {@code specificCounts} and {@code commonCounts}, and then
     * overwrites that document's theme weights with the normalized per-theme totals.
     */
    private void accumulateCollection(int c, double[][] specificCounts, double[][] commonCounts,
            double[] docThemeCounts) {
        IntSparseMatrix reader = arrTopicReader[c];
        double[][] docWeight = arrDocWeight[c];
        int docCount = reader.rows();
        for (int doc = 0; doc < docCount; doc++) {
            int[] terms = reader.getNonZeroColumnsInRow(doc);
            int[] weights = reader.getNonZeroIntScoresInRow(doc);
            Arrays.fill(docThemeCounts, 0.0d);
            for (int n = 0; n < terms.length; n++) {
                int term = terms[n];
                // Probability of this term under the document's theme mixture (background
                // excluded). It must be indexed by the term id, not by the position n in the
                // sparse row, or the theme posteriors below would not sum to one.
                double themeMix = 0.0d;
                for (int k = 0; k < themeNum; k++) {
                    themeMix += themeTermProb(c, k, term) * docWeight[k][doc];
                }
                double bkgPosterior = backgroundPosterior(term, themeMix);
                for (int k = 0; k < themeNum; k++) {
                    double themeProb = themeTermProb(c, k, term);
                    double commonPart = comCoefficient * arrCommonProb[k][term];
                    double themePosterior =
                            themeMix == 0.0d ? 0.0d : themeProb * docWeight[k][doc] / themeMix;
                    // Fraction of this theme's probability contributed by the common component.
                    double commonShare = themeProb > 0.0d ? commonPart / themeProb : 0.0d;
                    double count = weights[n] * themePosterior;
                    docThemeCounts[k] += count;
                    double themeCount = count * (1.0d - bkgPosterior);
                    specificCounts[k][term] += themeCount * (1.0d - commonShare);
                    commonCounts[k][term] += themeCount * commonShare;
                }
            }
            double total = 0.0d;
            for (double v : docThemeCounts) {
                total += v;
            }
            for (int k = 0; k < themeNum; k++) {
                docWeight[k][doc] = total > 0.0d ? docThemeCounts[k] / total : 0.0d;
            }
        }
    }

    /** Returns the probability of {@code term} under theme {@code k} as seen from collection {@code c}. */
    private double themeTermProb(int c, int k, int term) {
        return comCoefficient * arrCommonProb[k][term]
                + (1.0d - comCoefficient) * arrProb[c][k][term];
    }

    /**
     * Returns the posterior probability that an occurrence of {@code term} was generated by the
     * background. It is 0 when both background and theme mixture assign the term zero
     * probability; this avoids 0/0.
     */
    private double backgroundPosterior(int term, double themeMix) {
        double bkg = bkgModel[term];
        double denominator = themeMix * (1.0d - bkgCoefficient) + bkg;
        return denominator > 0.0d ? bkg / denominator : 0.0d;
    }

    /**
     * Writes {@code counts} scaled to sum to one into {@code target}. If the counts sum to
     * zero, the target is filled with zeros instead of NaN.
     */
    private static void normalizeInto(double[] counts, double[] target) {
        double total = 0.0d;
        for (double v : counts) {
            total += v;
        }
        for (int i = 0; i < counts.length; i++) {
            target[i] = total > 0.0d ? counts[i] / total : 0.0d;
        }
    }

    /**
     * Seeds the model before estimation.
     *
     * <ul>
     *   <li>Every common and collection-specific theme distribution is set to the uniform value
     *       {@code 1 / termCount}.</li>
     *   <li>Every document's theme weights are drawn at random and normalized to sum to one.
     *       The generator is seeded from {@code seed} when it is non-negative, so runs are
     *       reproducible.</li>
     * </ul>
     *
     * @param termCount number of term slots to fill in each distribution
     * @param collectionCount number of collections
     * @param themeCount number of themes
     * @param docCount number of document slots per collection
     * @param commonProb common models, indexed [theme][term]
     * @param specificProb collection-specific models, indexed [collection][theme][term]
     * @param docWeight document theme weights, indexed [collection][theme][document]
     */
    protected void initialize(int termCount, int collectionCount, int themeCount, int docCount,
            double[][] commonProb, double[][][] specificProb, double[][][] docWeight) {
        double uniform = 1.0d / termCount;
        for (int k = 0; k < themeCount; k++) {
            Arrays.fill(commonProb[k], 0, termCount, uniform);
        }
        for (int c = 0; c < collectionCount; c++) {
            for (int k = 0; k < themeCount; k++) {
                Arrays.fill(specificProb[c][k], 0, termCount, uniform);
            }
        }
        Random random = this.seed >= 0 ? new Random(this.seed) : new Random();
        // Draw order (collection, document, theme) is kept stable so seeded runs are repeatable.
        for (int c = 0; c < collectionCount; c++) {
            for (int doc = 0; doc < docCount; doc++) {
                double total = 0.0d;
                for (int k = 0; k < themeCount; k++) {
                    docWeight[c][k][doc] = random.nextDouble();
                    total += docWeight[c][k][doc];
                }
                for (int k = 0; k < themeCount; k++) {
                    docWeight[c][k][doc] /= total;
                }
            }
        }
    }
}
