package hex.util;

import Jama.CholeskyDecomposition;
import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import hex.DataInfo;
import hex.FrameTask;
import hex.ToEigenVec;
import hex.gram.Gram;
import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/util/LinearAlgebraUtils.class */
public class LinearAlgebraUtils {
    /**
     * Shared {@link ToEigenVec} adapter. The static initializer of this class assigns it an
     * instance whose {@code toEigenVec} delegates to {@link #toEigen(Vec)}.
     */
    public static ToEigenVec toEigen;
    // Decompiler rendering of the compiler-generated assertion flag. It is assigned in the static
    // initializer from LinearAlgebraUtils.class.desiredAssertionStatus() (negated) and is read by
    // the hand-expanded assertion guards in this class, e.g. in forwardSolve.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/util/LinearAlgebraUtils$BMulInPlaceTask.class */
    /**
     * Map task that, for each row, writes the inner product of the dense row of X with every
     * array in {@code _yt} into the chunks that follow X's columns.
     *
     * <p>The chunk array handed to {@link #map(Chunk[])} must hold the {@code _ncolX} columns of
     * X first, followed by one writable column per array in {@code _yt}. Rows that the
     * {@link DataInfo} reports as bad are skipped and their output cells are left as they were.
     */
    public static class BMulInPlaceTask extends MRTask<BMulInPlaceTask> {
        final DataInfo _xinfo;
        final double[][] _yt;
        final int _ncolX;

        /**
         * @param dataInfo description of X; its adapted frame supplies the column count
         * @param dArr     coefficient arrays; each is expected to have one entry per expanded
         *                 column of X (checked only when assertions are enabled)
         */
        public BMulInPlaceTask(DataInfo dataInfo, double[][] dArr) {
            assert dArr != null && dArr[0].length == LinearAlgebraUtils.numColsExp(dataInfo._adaptedFrame, true);
            this._xinfo = dataInfo;
            this._ncolX = dataInfo._adaptedFrame.numCols();
            this._yt = dArr;
        }

        /**
         * Fills the trailing {@code _yt.length} chunks with the row-wise inner products.
         *
         * @param chunkArr X chunks followed by the output chunks
         */
        public void map(Chunk[] chunkArr) {
            assert chunkArr.length == this._ncolX + this._yt.length;

            // A single row object is reused for every row of the chunk.
            DataInfo.Row xRow = this._xinfo.newDenseRow();

            // Only the leading _ncolX chunks belong to X.
            Chunk[] xChunks = new Chunk[this._ncolX];
            System.arraycopy(chunkArr, 0, xChunks, 0, this._ncolX);

            for (int row = 0; row < chunkArr[0]._len; row++) {
                this._xinfo.extractDenseRow(xChunks, row, xRow);
                if (xRow.isBad()) {
                    continue;
                }
                for (int k = 0; k < this._yt.length; k++) {
                    chunkArr[this._ncolX + k].set(row, xRow.innerProduct(this._yt[k]));
                }
                // Every output chunk must have been written exactly once.
                assert this._ncolX + this._yt.length == chunkArr.length;
            }
        }
    }

    /* loaded from: input_file:hex/util/LinearAlgebraUtils$BMulTask.class */
    /**
     * Frame task that emits, for every processed row, the inner product of that row with each
     * array in {@code _yt}; the k-th product is appended to the k-th output chunk.
     */
    public static class BMulTask extends FrameTask<BMulTask> {
        final double[][] _yt;

        /**
         * @param key      job key forwarded to {@link FrameTask}
         * @param dataInfo data description forwarded to {@link FrameTask}
         * @param dArr     coefficient arrays, one per output column
         */
        public BMulTask(Key<Job> key, DataInfo dataInfo, double[][] dArr) {
            super(key, dataInfo);
            this._yt = dArr;
        }

        /**
         * Appends one value per coefficient array to the matching output chunk.
         *
         * @param j           row identifier supplied by the framework (unused here)
         * @param row         the current row
         * @param newChunkArr output chunks, indexed like {@code _yt}
         */
        @Override // hex.FrameTask
        protected void processRow(long j, DataInfo.Row row, NewChunk[] newChunkArr) {
            int outCol = 0;
            for (double[] coeffs : this._yt) {
                newChunkArr[outCol++].addNum(row.innerProduct(coeffs));
            }
        }
    }

    /* loaded from: input_file:hex/util/LinearAlgebraUtils$ForwardSolve.class */
    /**
     * Map/reduce task that solves {@code L * q = a} by forward substitution for every row
     * {@code a} of A (with categoricals expanded) and stores {@code q} into the second half of
     * the chunk array, while accumulating the squared difference between the new and the
     * previously stored values in {@link #_sse}.
     *
     * <p>The chunk array handed to {@link #map(Chunk[])} must contain the {@code _ncols} columns
     * of A followed by {@code _ncols} writable columns for the result. Rows that the
     * {@link DataInfo} reports as bad are skipped.
     */
    public static class ForwardSolve extends MRTask<ForwardSolve> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;
        /** Output: sum over all processed cells of (new value - old value) squared. */
        public double _sse;

        /**
         * @param dataInfo description of A; its adapted frame supplies the column count
         * @param dArr     square matrix L with as many rows as A has columns (checked only when
         *                 assertions are enabled)
         */
        public ForwardSolve(DataInfo dataInfo, double[][] dArr) {
            assert dArr != null && dArr.length == dArr[0].length && dArr.length == dataInfo._adaptedFrame.numCols();
            this._ainfo = dataInfo;
            this._ncols = dataInfo._adaptedFrame.numCols();
            this._L = dArr;
            this._sse = 0.0d;
        }

        /**
         * Solves each good row and overwrites the trailing {@code _ncols} chunks with the result.
         *
         * @param chunkArr A chunks followed by the result chunks
         */
        public void map(Chunk[] chunkArr) {
            assert 2 * this._ncols == chunkArr.length;

            // Only the leading _ncols chunks belong to A.
            Chunk[] aChunks = new Chunk[this._ncols];
            System.arraycopy(chunkArr, 0, aChunks, 0, this._ncols);

            for (int row = 0; row < chunkArr[0]._len; row++) {
                DataInfo.Row aRow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(aChunks, row, aRow);
                if (aRow.isBad()) {
                    continue;
                }
                double[] solved = LinearAlgebraUtils.forwardSolve(this._L, aRow.expandCats());
                for (int k = 0; k < this._ncols; k++) {
                    Chunk target = chunkArr[this._ncols + k];
                    // Compare against the old value before it is overwritten.
                    double diff = solved[k] - target.atd(row);
                    this._sse += diff * diff;
                    target.set(row, solved[k]);
                }
                assert this._ncols == solved.length;
            }
        }

        /**
         * Folds another task's partial sum into this one. Without this method the partial sums
         * gathered by the other task copies are never combined, so {@link #_sse} would cover only
         * part of the frame; this mirrors how {@code SMulTask.reduce} merges its partial result.
         *
         * @param other task whose partial {@code _sse} is added to this task
         */
        public void reduce(ForwardSolve other) {
            this._sse += other._sse;
        }
    }

    /* loaded from: input_file:hex/util/LinearAlgebraUtils$ForwardSolveInPlace.class */
    /**
     * Map task that replaces every good row {@code a} of A with the solution {@code q} of
     * {@code L * q = a}, computed by forward substitution, writing the result back into the same
     * chunks. Rows that the {@link DataInfo} reports as bad are left untouched.
     */
    public static class ForwardSolveInPlace extends MRTask<ForwardSolveInPlace> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;

        /**
         * @param dataInfo description of A; its adapted frame supplies the column count
         * @param dArr     square matrix L with as many rows as A has columns (checked only when
         *                 assertions are enabled)
         */
        public ForwardSolveInPlace(DataInfo dataInfo, double[][] dArr) {
            assert dArr != null && dArr.length == dArr[0].length && dArr.length == dataInfo._adaptedFrame.numCols();
            this._ainfo = dataInfo;
            this._ncols = dataInfo._adaptedFrame.numCols();
            this._L = dArr;
        }

        /**
         * Overwrites each good row of the chunk with its forward-substitution solution.
         *
         * @param chunkArr the {@code _ncols} chunks of A, read and then overwritten
         */
        public void map(Chunk[] chunkArr) {
            assert this._ncols == chunkArr.length;

            // Separate array of the same chunk references, used for row extraction.
            Chunk[] aChunks = new Chunk[this._ncols];
            System.arraycopy(chunkArr, 0, aChunks, 0, this._ncols);

            for (int row = 0; row < chunkArr[0]._len; row++) {
                DataInfo.Row aRow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(aChunks, row, aRow);
                if (aRow.isBad()) {
                    continue;
                }
                double[] solved = LinearAlgebraUtils.forwardSolve(this._L, aRow.expandCats());
                assert solved.length == this._ncols;
                int col = 0;
                while (col < this._ncols) {
                    chunkArr[col].set(row, solved[col]);
                    col++;
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/util/LinearAlgebraUtils$ProjectOntoEigenVector.class */
    /**
     * Map task that replaces each value of the first input column by the entry of
     * {@code _yCoord} at that value's integer index; missing values stay missing.
     */
    public static class ProjectOntoEigenVector extends MRTask<ProjectOntoEigenVector> {
        final double[] _yCoord;

        /**
         * @param dArr lookup table indexed by the integer value read from the input column
         */
        ProjectOntoEigenVector(double[] dArr) {
            this._yCoord = dArr;
        }

        /**
         * Appends one output value per input row to the first output chunk.
         *
         * @param chunkArr    input chunks; only the first one is read
         * @param newChunkArr output chunks; only the first one is written
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            for (int row = 0; row < chunkArr[0]._len; row++) {
                if (!chunkArr[0].isNA(row)) {
                    int level = (int) chunkArr[0].at8(row);
                    // The coordinate is narrowed to float before it is stored.
                    newChunkArr[0].addNum((float) this._yCoord[level]);
                } else {
                    newChunkArr[0].addNA();
                }
            }
        }
    }

    /* loaded from: input_file:hex/util/LinearAlgebraUtils$SMulTask.class */
    /**
     * Map/reduce task that accumulates the product of the expanded A (transposed) with Q into
     * {@link #_atq}, which has one row per expanded column of A and one column per column of Q.
     *
     * <p>The chunk array handed to {@link #map(Chunk[])} must contain the {@code _ncolA} columns
     * of A (categoricals first, as counted by {@code _ainfo._cats}) followed by the
     * {@code _ncolQ} columns of Q.
     */
    public static class SMulTask extends MRTask<SMulTask> {
        final DataInfo _ainfo;
        final int _ncolA;
        final int _ncolExp;
        final int _ncolQ;
        /** Output: accumulated product, sized {@code _ncolExp} by {@code _ncolQ}. */
        public double[][] _atq;

        /**
         * @param dataInfo description of A; its adapted frame supplies the raw and expanded
         *                 column counts
         * @param i        number of columns of Q
         */
        public SMulTask(DataInfo dataInfo, int i) {
            this._ainfo = dataInfo;
            this._ncolA = dataInfo._adaptedFrame.numCols();
            this._ncolExp = LinearAlgebraUtils.numColsExp(dataInfo._adaptedFrame, true);
            this._ncolQ = i;
        }

        /**
         * Computes this chunk's contribution to {@link #_atq}.
         *
         * @param chunkArr A chunks followed by Q chunks
         */
        public void map(Chunk[] chunkArr) {
            assert this._ncolA + this._ncolQ == chunkArr.length;
            this._atq = new double[this._ncolExp][this._ncolQ];

            for (int q = this._ncolA; q < this._ncolA + this._ncolQ; q++) {
                final int qIdx = q - this._ncolA;

                // Categorical columns: the indicator is 1, so the Q value itself is added to the
                // row of the matching expanded level.
                for (int cat = 0; cat < this._ainfo._cats; cat++) {
                    for (int row = 0; row < chunkArr[0]._len; row++) {
                        if (chunkArr[cat].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        double qVal = chunkArr[q].atd(row);
                        double level = chunkArr[cat].atd(row);
                        final int categoricalId;
                        if (!Double.isNaN(level)) {
                            categoricalId = this._ainfo.getCategoricalId(cat, (int) level);
                        } else if (this._ainfo._imputeMissing) {
                            categoricalId = this._ainfo.catNAFill()[cat];
                        } else if (this._ainfo._catMissing[cat]) {
                            // Missing value maps to the last slot of this column's level range.
                            categoricalId = this._ainfo._catOffsets[cat + 1] - 1;
                        } else {
                            // Missing value with neither imputation nor an NA slot: nothing to
                            // add. The decompiled code left categoricalId unassigned on this path.
                            continue;
                        }
                        if (categoricalId >= 0) {
                            this._atq[categoricalId][qIdx] += qVal;
                        }
                    }
                }

                // Numeric columns: one expanded row each, starting at numStart().
                int numIdx = 0;
                int outRow = this._ainfo.numStart();
                for (int col = this._ainfo._cats; col < this._ncolA; col++, numIdx++, outRow++) {
                    for (int row = 0; row < chunkArr[0]._len; row++) {
                        if (chunkArr[col].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        double qVal = chunkArr[q].atd(row);
                        double aVal = LinearAlgebraUtils.modifyNumeric(chunkArr[col].atd(row), numIdx, this._ainfo);
                        this._atq[outRow][qIdx] += qVal * aVal;
                    }
                }
                assert outRow == this._atq.length;
            }
        }

        /**
         * Adds another task's partial product into this task's {@link #_atq}.
         *
         * @param sMulTask task whose partial result is merged in
         */
        public void reduce(SMulTask sMulTask) {
            ArrayUtils.add(this._atq, sMulTask._atq);
        }
    }

    /**
     * Solves {@code L * x = b} by forward substitution, reading only the diagonal and the
     * entries below it of {@code dArr}.
     *
     * @param dArr  square matrix L; must have as many rows as {@code dArr2} has entries (checked
     *              only when assertions are enabled)
     * @param dArr2 right-hand side b; not modified
     * @return a newly allocated solution vector x
     */
    public static double[] forwardSolve(double[][] dArr, double[] dArr2) {
        if (!$assertionsDisabled) {
            boolean wellFormed = dArr != null && dArr.length == dArr[0].length && dArr.length == dArr2.length;
            if (!wellFormed) {
                throw new AssertionError();
            }
        }
        final int n = dArr2.length;
        final double[] solution = new double[n];
        for (int row = 0; row < n; row++) {
            double acc = dArr2[row];
            // Subtract the contribution of the unknowns that are already solved.
            for (int col = 0; col < row; col++) {
                acc -= dArr[row][col] * solution[col];
            }
            solution[row] = acc / dArr[row][row];
        }
        return solution;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Applies the {@link DataInfo}'s missing-value imputation and normalisation to one numeric
     * value.
     *
     * @param d        raw value
     * @param i        index of the numeric column, used to look up its mean and normalisation
     * @param dataInfo supplies {@code _imputeMissing}, {@code _numMeans}, {@code _normSub} and
     *                 {@code _normMul}
     * @return the column mean in place of NaN when imputation is on, then shifted and scaled when
     *         both normalisation arrays are present; otherwise the value unchanged
     */
    public static double modifyNumeric(double d, int i, DataInfo dataInfo) {
        double value = d;
        if (Double.isNaN(d) && dataInfo._imputeMissing) {
            value = dataInfo._numMeans[i];
        }
        if (dataInfo._normSub == null || dataInfo._normMul == null) {
            return value;
        }
        return (value - dataInfo._normSub[i]) * dataInfo._normMul[i];
    }

    /**
     * Expands one raw row (categorical level codes first, then numeric values) into the
     * caller-supplied expanded array: each categorical column sets the slot of its level to 1,
     * and the numeric values are copied to consecutive slots starting at
     * {@code dataInfo.numStart()}.
     *
     * <p>Slots that are not written keep whatever value {@code dArr2} already held, so the caller
     * is responsible for clearing the array when reusing it.
     *
     * @param dArr     raw row: {@code dataInfo._cats} categorical codes followed by
     *                 {@code dataInfo._nums} numeric values
     * @param dataInfo layout and missing-value handling for the row
     * @param dArr2    destination array for the expanded row
     * @param z        when true, numeric values go through {@link #modifyNumeric} first
     * @return {@code dArr2}
     */
    public static double[] expandRow(double[] dArr, DataInfo dataInfo, double[] dArr2, boolean z) {
        for (int cat = 0; cat < dataInfo._cats; cat++) {
            final int categoricalId;
            if (!Double.isNaN(dArr[cat])) {
                categoricalId = dataInfo.getCategoricalId(cat, (int) dArr[cat]);
            } else if (dataInfo._imputeMissing) {
                categoricalId = dataInfo.catNAFill()[cat];
            } else if (dataInfo._catMissing[cat]) {
                // Missing value maps to the last slot of this column's level range.
                categoricalId = dataInfo._catOffsets[cat + 1] - 1;
            } else {
                // Missing value with neither imputation nor an NA slot: leave every indicator of
                // this column untouched. The decompiled code left categoricalId unassigned here.
                continue;
            }
            if (categoricalId >= 0) {
                dArr2[categoricalId] = 1.0d;
            }
        }

        int src = dataInfo._cats;
        int dst = dataInfo.numStart();
        for (int num = 0; num < dataInfo._nums; num++, src++, dst++) {
            dArr2[dst] = z ? modifyNumeric(dArr[src], num, dataInfo) : dArr[src];
        }
        return dArr2;
    }

    /**
     * Runs a {@link Gram.GramTask} over the adapted frame, takes the Cholesky factor L of the
     * resulting Gram matrix and scales it by {@code sqrt(nobs)}.
     *
     * @param key      job key passed to the Gram task
     * @param dataInfo description of the frame to decompose
     * @param z        when true the scaled factor L is returned as is; when false its transpose
     *                 is returned
     * @return the scaled Cholesky factor or its transpose
     */
    public static double[][] computeR(Key<Job> key, DataInfo dataInfo, boolean z) {
        Gram.GramTask gramTask = new Gram.GramTask(key, dataInfo);
        gramTask.doAll(dataInfo._adaptedFrame);

        Matrix gramMatrix = new Matrix(gramTask._gram.getXX());
        CholeskyDecomposition cholesky = new CholeskyDecomposition(gramMatrix);
        double[][] lower = cholesky.getL().getArray();

        // NOTE(review): relies on ArrayUtils.mult scaling the array in place, since the return
        // value is ignored. Presumably the Gram is divided by nobs, hence the rescale; confirm.
        ArrayUtils.mult(lower, Math.sqrt(gramTask._nobs));

        if (z) {
            return lower;
        }
        return ArrayUtils.transpose(lower);
    }

    /**
     * Runs a {@link ForwardSolve} over {@code frame} using the factor returned by
     * {@code computeR(key, dataInfo, true)}.
     *
     * @param key      job key passed on to {@link #computeR}
     * @param dataInfo description of A
     * @param frame    frame the task runs over; {@code ForwardSolve.map} expects A's columns
     *                 followed by the same number of result columns
     * @return the task's {@code _sse} after it has run
     */
    public static double computeQ(Key<Job> key, DataInfo dataInfo, Frame frame) {
        double[][] factor = computeR(key, dataInfo, true);
        ForwardSolve solver = new ForwardSolve(dataInfo, factor);
        solver.doAll(frame);
        return solver._sse;
    }

    /**
     * Runs a {@link ForwardSolveInPlace} over the adapted frame of {@code dataInfo}, using the
     * factor returned by {@code computeR(key, dataInfo, true)}; the frame's values are
     * overwritten with the solved rows.
     *
     * @param key      job key passed on to {@link #computeR}
     * @param dataInfo description of the frame to overwrite
     */
    public static void computeQInPlace(Key<Job> key, DataInfo dataInfo) {
        double[][] factor = computeR(key, dataInfo, true);
        ForwardSolveInPlace solver = new ForwardSolveInPlace(dataInfo, factor);
        solver.doAll(dataInfo._adaptedFrame);
    }

    /**
     * Counts the columns of {@code frame} after categorical expansion: a categorical vec with a
     * non-null domain counts once per level (minus one when {@code z} is false), every other vec
     * counts once.
     *
     * @param frame frame whose vecs are counted
     * @param z     true to count every level, false to count one level fewer per categorical
     * @return the expanded column count
     */
    public static int numColsExp(Frame frame, boolean z) {
        final int dropped = z ? 0 : 1;
        int total = 0;
        for (Vec vec : frame.vecs()) {
            if (vec.isCategorical() && vec.domain() != null) {
                total += vec.domain().length - dropped;
            } else {
                total += 1;
            }
        }
        return total;
    }

    /**
     * Builds a square matrix from the given diagonal values and returns the row of its
     * eigenvector array that corresponds to the largest real eigenvalue.
     *
     * <p>Side effect: {@code dArr} is scaled in place by {@code i} before the matrix is built.
     * Matrix entries that evaluate to NaN (for example from a zero diagonal value) are stored
     * as 0.
     *
     * @param dArr diagonal values; multiplied by {@code i} in place
     * @param i    scale factor, also used as the divisor of the cross term
     * @param i2   additional divisor applied to every entry
     * @return one row of {@code getV().getArray()}, selected by the index of the largest real
     *         eigenvalue
     */
    static double[] multiple(double[] dArr, int i, int i2) {
        final int n = dArr.length;
        for (int k = 0; k < n; k++) {
            dArr[k] *= i;
        }

        final double[][] matrix = new double[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                double diagonalTerm = (r == c) ? dArr[r] : 0.0d;
                double entry = (diagonalTerm - ((dArr[r] * dArr[c]) / i)) / (i2 * Math.sqrt(dArr[r] * dArr[c]));
                matrix[r][c] = Double.isNaN(entry) ? 0.0d : entry;
            }
        }

        EigenvalueDecomposition eigen = new EigenvalueDecomposition(new Matrix(matrix));
        // NOTE(review): this indexes a ROW of V; Jama documents eigenvectors as the columns of V.
        // Confirm this is intended before changing it.
        double[][] vectors = eigen.getV().getArray();
        int largest = ArrayUtils.maxIndex(eigen.getRealEigenvalues());
        return vectors[largest];
    }

    /**
     * Maps a vec to a numeric vec by projecting each value onto a vector derived from the
     * Gram diagonal of the single-column frame built around {@code vec}.
     *
     * <p>Steps: wrap {@code vec} in a one-column frame named "enum", build a {@link DataInfo}
     * for it and publish it to the DKV, run a {@link Gram.GramTask}, copy the Gram diagonal
     * (narrowed to float), remove the {@link DataInfo}, pass the diagonal and the observation
     * count to {@link #multiple}, and finally run {@link ProjectOntoEigenVector} over the frame.
     *
     * @param vec input vec; its values are used as integer indices into the projection vector
     * @return the single vec of the projection task's output frame
     */
    public static Vec toEigen(Vec vec) {
        Frame frame = new Frame(Key.make(), new String[]{"enum"}, new Vec[]{vec});
        DataInfo dataInfo = new DataInfo(frame, (Frame) null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, true, false, false, false, false, false);
        DKV.put(dataInfo);

        Gram.GramTask gramTask = (Gram.GramTask) new Gram.GramTask(null, dataInfo).doAll(dataInfo._adaptedFrame);

        // Copy the diagonal, deliberately narrowing each entry to float first.
        final int levels = gramTask._gram._diag.length;
        double[] diagonal = new double[levels];
        for (int k = 0; k < levels; k++) {
            diagonal[k] = (float) gramTask._gram._diag[k];
        }
        dataInfo.remove();

        double[] projection = multiple(diagonal, (int) gramTask._nobs, 1);
        ProjectOntoEigenVector projector = new ProjectOntoEigenVector(projection);
        // One output column of vec type code 3 (presumably the numeric type; confirm against Vec).
        ProjectOntoEigenVector finished = (ProjectOntoEigenVector) projector.doAll(1, (byte) 3, frame);
        return finished.outputFrame().anyVec();
    }

    // Class initializer: sets up the assertion flag and publishes the ToEigenVec adapter.
    static {
        // Decompiled form of the compiler's assertion setup: the flag is true when assertions
        // are NOT enabled for this class.
        $assertionsDisabled = !LinearAlgebraUtils.class.desiredAssertionStatus();
        // Adapter that exposes the static toEigen(Vec) method through the ToEigenVec interface.
        toEigen = new ToEigenVec() { // from class: hex.util.LinearAlgebraUtils.1
            public Vec toEigenVec(Vec vec) {
                return LinearAlgebraUtils.toEigen(vec);
            }
        };
    }
}
