package hex.psvm;

import hex.DataInfo;
import hex.FrameTask;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.psvm.PSVMModel;
import hex.psvm.psvm.IncompleteCholeskyFactorization;
import hex.psvm.psvm.Kernel;
import hex.psvm.psvm.PrimalDualIPM;
import java.util.ArrayList;
import java.util.Arrays;
import water.AutoBuffer;
import water.DKV;
import water.H2O;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.TwoDimTable;
import water.util.VecUtils;

/* loaded from: input_file:hex/psvm/PSVM.class */
/**
 * Model builder for a support vector machine trained with an Incomplete Cholesky
 * Factorization (ICF) followed by a primal-dual interior point method (IPM).
 *
 * <p>Only binomial problems are supported. The response must be either a two-level
 * categorical column or a numeric column containing only -1/+1 (see
 * {@link #checkDistributions()}).
 *
 * <p>Training steps, as implemented by the driver's {@code computeImpl()}:
 * <ol>
 *   <li>Adapt the training frame: encode the response as -1/+1 and append a zero
 *       {@code two_norm_sq} column.</li>
 *   <li>Run {@code IncompleteCholeskyFactorization.icf} with the configured kernel.</li>
 *   <li>Solve for the alpha coefficients with {@code PrimalDualIPM.solve}.</li>
 *   <li>Regulate the alphas and pick the support vectors.</li>
 *   <li>Estimate {@code rho} from a sample of the support vectors.</li>
 *   <li>Compress the support vectors into the model, unless the estimated size is too
 *       large.</li>
 *   <li>Optionally score the training and validation frames.</li>
 * </ol>
 */
public class PSVM extends ModelBuilder<PSVMModel, PSVMModel.PSVMParameters, PSVMModel.PSVMModelOutput> {

    /** Maximum number of support vectors sampled to estimate {@code rho} and the model size. */
    private static final int SV_SAMPLE_SIZE = 1000;

    /**
     * Computes the bias term {@code rho} from a sample of support-vector rows.
     *
     * <p>For each sampled row {@code s}, this task accumulates {@code alpha(j) * K(row_j, s)}
     * over every row {@code j} of the frame whose alpha is not NA. {@link #getRho()} then
     * averages {@code response(s) - accumulated(s)} over the sample.
     */
    private static class CalculateRhoTask extends FrameTask<CalculateRhoTask> {
        /** Sampled support-vector rows. */
        DataInfo.Row[] _selected;
        /** Per-row alpha values aligned with the processed frame; rows with NA alpha are skipped. */
        Vec _alpha;
        Kernel _kernel;
        /** Per-sample accumulated kernel sums; merged by element-wise addition in reduce. */
        double[] _rhos;
        /** Global row index of the first row in the current alpha chunk. */
        transient long _offset;
        transient Chunk _alphaChunk;

        public CalculateRhoTask(DataInfo dinfo, DataInfo.Row[] selected, Vec alpha, Kernel kernel) {
            super(null, dinfo);
            _selected = selected;
            _alpha = alpha;
            _kernel = kernel;
        }

        @Override
        public void map(Chunk[] chunks, NewChunk[] outputs) {
            _alphaChunk = _alpha.chunkForChunkIdx(chunks[0].cidx());
            _offset = _alphaChunk.start();
            _rhos = new double[_selected.length];
            super.map(chunks, outputs);
        }

        @Override
        protected boolean skipRow(long gid) {
            return _alphaChunk.isNA((int) (gid - _offset));
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row) {
            // The alpha of the current row does not depend on the sample index, so read it once.
            final double alpha = _alphaChunk.atd((int) (gid - _offset));
            for (int i = 0; i < _selected.length; i++) {
                _rhos[i] += alpha * _kernel.calcKernel(row, _selected[i]);
            }
        }

        @Override
        public void reduce(CalculateRhoTask other) {
            _rhos = ArrayUtils.add(_rhos, other._rhos);
        }

        /**
         * Returns the mean of {@code response - accumulated} over the sampled rows.
         * The result is NaN when the sample is empty (0/0).
         */
        double getRho() {
            double sum = 0.0d;
            for (int i = 0; i < _selected.length; i++) {
                sum += _selected[i].response[0] - _rhos[i];
            }
            return sum / _selected.length;
        }
    }

    /**
     * Collects (deep copies of) a sample of support-vector rows.
     *
     * <p>{@code _svs} holds the global row indices of the support vectors. Each node takes a
     * share of {@code _num_selected} proportional to how many of those indices it stores
     * locally, then captures the matching rows while the task runs over the training frame.
     */
    private static class CollectSupportVecSamplesTask extends FrameTask<CollectSupportVecSamplesTask> {
        private Vec _svs;
        private int _num_selected;
        /** Sampled rows, one slot per node (indexed by {@code H2O.SELF.index()}). */
        DataInfo.Row[][] _selected;
        /** Sorted global row indices this node has to capture. */
        private transient long[] _local_selected_idxs;

        CollectSupportVecSamplesTask(DataInfo dinfo, Vec svs, int numSelected) {
            super(null, dinfo);
            _svs = svs;
            _num_selected = numSelected;
        }

        @Override
        protected void setupLocal() {
            super.setupLocal();
            _selected = new DataInfo.Row[H2O.CLOUD.size()][];
            final int[] localChunkIds = VecUtils.getLocalChunkIds(_svs);
            long localSvs = 0;
            for (int cidx : localChunkIds) {
                localSvs += _svs.chunkLen(cidx);
            }
            final long totalSvs = _svs.length();
            // Guard against 0/0 when the model has no support vectors at all.
            final int sampleSize = totalSvs == 0 ? 0 : (int) ((_num_selected * localSvs) / totalSvs);
            _local_selected_idxs = new long[sampleSize];
            _selected[H2O.SELF.index()] = new DataInfo.Row[sampleSize];
            int filled = 0;
            for (int cidx : localChunkIds) {
                if (filled == sampleSize) {
                    break;
                }
                final Chunk chunk = _svs.chunkForChunkIdx(cidx);
                // Check the bound before writing: the share can round down to 0 on a node that
                // does hold support vectors.
                for (int i = 0; i < chunk._len && filled < sampleSize; i++) {
                    _local_selected_idxs[filled++] = chunk.at8(i);
                }
            }
            Arrays.sort(_local_selected_idxs);
        }

        @Override
        protected boolean skipRow(long gid) {
            return Arrays.binarySearch(_local_selected_idxs, gid) < 0;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row) {
            _selected[H2O.SELF.index()][Arrays.binarySearch(_local_selected_idxs, gid)] = row.deepClone();
        }

        @Override
        public void reduce(CollectSupportVecSamplesTask other) {
            for (int node = 0; node < H2O.CLOUD.size(); node++) {
                if (other._selected[node] != null) {
                    _selected[node] = other._selected[node];
                }
            }
        }

        /** Returns the per-node samples flattened into a single array. */
        DataInfo.Row[] getSelected() {
            return ArrayUtils.flat(_selected);
        }
    }

    /**
     * Serializes every row whose alpha is not NA as a compressed {@link SupportVector}.
     * The alpha is taken from the last column of the input frame. Per-chunk byte buffers are
     * concatenated in reduce.
     */
    private static class CompressVectorsTask extends MRTask<CompressVectorsTask> {
        private final DataInfo _dinfo;
        private byte[] _csvs;

        CompressVectorsTask(DataInfo dinfo) {
            _dinfo = dinfo;
        }

        @Override
        public void map(Chunk[] chunks) {
            final Chunk alpha = chunks[chunks.length - 1];
            final Chunk[] features = Arrays.copyOf(chunks, chunks.length - 1);
            final DataInfo.Row row = _dinfo.newDenseRow();
            final SupportVector sv = new SupportVector();
            final AutoBuffer ab = new AutoBuffer();
            for (int i = 0; i < alpha._len; i++) {
                if (alpha.isNA(i)) {
                    continue; // not a support vector
                }
                _dinfo.extractDenseRow(features, i, row);
                sv.fill(alpha.atd(i), row.numVals, row.binIds);
                sv.compress(ab);
            }
            _csvs = ab.buf();
        }

        @Override
        public void reduce(CompressVectorsTask other) {
            _csvs = ArrayUtils.append(_csvs, other._csvs);
        }
    }

    /** Keeps the values of the most recent IPM progress report, for use in the model summary. */
    public static class IPMInfo implements PrimalDualIPM.ProgressObserver {
        int _iter;
        double _sgap;
        double _resp;
        double _resd;
        boolean _converged;

        private IPMInfo() {
        }

        @Override
        public void reportProgress(int iter, double sgap, double resp, double resd, boolean converged) {
            _iter = iter;
            _sgap = sgap;
            _resp = resp;
            _resd = resd;
            _converged = converged;
        }
    }

    /**
     * Post-processes the IPM solution in place.
     *
     * <p>Expected inputs: (alpha, label). The task changes the alpha column as follows:
     * <ul>
     *   <li>An alpha at or below {@code sv_threshold} becomes NA, meaning the row is not a
     *       support vector.</li>
     *   <li>An alpha within {@code sv_threshold} of its class bound (C+ for a positive label,
     *       C- otherwise) is clamped to that bound and counted as a bounded support vector.</li>
     *   <li>Every retained alpha is multiplied by its label.</li>
     * </ul>
     * The single output column holds the global row index of each support vector.
     */
    private static class RegulateAlphaTask extends MRTask<RegulateAlphaTask> {
        private double _c_pos;
        private double _c_neg;
        private double _sv_threshold;
        long _svs_count;
        long _bsv_count;

        private RegulateAlphaTask(double cPos, double cNeg, double svThreshold) {
            _c_pos = cPos;
            _c_neg = cNeg;
            _sv_threshold = svThreshold;
        }

        @Override
        public void map(Chunk alpha, Chunk label, NewChunk svIndices) {
            final long start = alpha.start();
            for (int i = 0; i < alpha._len; i++) {
                final double a = alpha.atd(i);
                if (a <= _sv_threshold) {
                    alpha.setNA(i);
                    continue;
                }
                _svs_count++;
                svIndices.addNum(start + i);
                final double y = label.atd(i);
                final double bound = y > 0.0d ? _c_pos : _c_neg;
                final double regulated;
                if (bound - a <= _sv_threshold) {
                    regulated = bound;
                    _bsv_count++;
                } else {
                    regulated = a;
                }
                alpha.set(i, regulated * y);
            }
        }

        @Override
        public void reduce(RegulateAlphaTask other) {
            _svs_count += other._svs_count;
            _bsv_count += other._bsv_count;
        }

        /**
         * Copies the support-vector counts into {@code output}.
         *
         * @return the vec of global support-vector row indices produced by this task
         */
        public Vec updateStats(PSVMModel.PSVMModelOutput output) {
            output._svs_count = _svs_count;
            output._bsv_count = _bsv_count;
            return outputFrame().vec(0);
        }
    }

    /** Training driver; see the class-level documentation for the sequence of steps. */
    private class SVMDriver extends Driver {

        /**
         * Builds the DataInfo used for training.
         *
         * <p>The adapted frame contains the predictors, then the response encoded as -1/+1,
         * then a zero-filled {@code two_norm_sq} column. A categorical response is mapped so
         * that level 0 becomes -1 and any other level becomes +1. This assumes the response
         * was already checked by {@link PSVM#checkDistributions()}; TODO confirm that this
         * always runs first.
         *
         * @throws IllegalStateException if the response contains NA values
         */
        DataInfo adaptTrain() {
            final Vec response = PSVM.this.response();
            if (response.naCnt() > 0) {
                throw new IllegalStateException("NA values in response column are currently not supported.");
            }
            final Frame adapted = new Frame(PSVM.this.train());
            adapted.remove(PSVM.this._parms._response_column);
            final Vec label;
            if (response.domain() == null) {
                label = response;
            } else {
                // This is a temporary vec. Track it in the Scope, the same way as the
                // two_norm_sq helper column below.
                label = Scope.track(new MRTask() {
                    @Override
                    public void map(Chunk c, NewChunk nc) {
                        for (int i = 0; i < c._len; i++) {
                            nc.addNum(c.at8(i) == 0 ? -1.0d : 1.0d);
                        }
                    }
                }.doAll(Vec.T_NUM, response).outputFrame().vec(0));
            }
            adapted.add(PSVM.this._parms._response_column, label);
            adapted.add("two_norm_sq", Scope.track(adapted.anyVec().makeZero()));
            return new DataInfo(adapted, (Frame) null, 2, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE,
                    true, false, false, false, false, false, (Model.InteractionSpec) null).disableIntercept();
        }

        /** Returns a copy of the adapted frame without the internal {@code two_norm_sq} column. */
        Frame prototypeFrame(DataInfo dinfo) {
            final Frame prototype = new Frame(dinfo._adaptedFrame);
            prototype.remove("two_norm_sq");
            return prototype;
        }

        @Override
        public void computeImpl() {
            PSVMModel model = null;
            try {
                PSVM.this.init(true);
                PSVM.this._job.update(0L, "Initializing model training");
                final PSVMModel.PSVMParameters parms = PSVM.this._parms;

                final DataInfo dinfo = adaptTrain();
                Scope.track_generic(dinfo);
                DKV.put(dinfo);
                if (parms._gamma == -1.0d) {
                    parms._gamma = 1.0d / dinfo.fullN();
                    Log.info("Set gamma = " + parms._gamma);
                }
                final Vec label = dinfo._adaptedFrame.vec(parms._response_column);

                final PSVMModel newModel = new PSVMModel(PSVM.this._result, parms,
                        new PSVMModel.PSVMModelOutput(PSVM.this, prototypeFrame(dinfo), PSVM.this.response().domain()));
                newModel.delete_and_lock(PSVM.this._job);
                model = newModel; // assign only once locked, so that finally only unlocks a locked model

                final int rankICF = PSVM.this.getRankICF(parms._rank_ratio, dinfo._adaptedFrame.numRows());
                Log.info("Desired rank of ICF matrix = " + rankICF);
                PSVM.this._job.update(0L, "Running Incomplete Cholesky Factorization");
                final Frame icf = IncompleteCholeskyFactorization.icf(dinfo, parms.kernel(), rankICF, parms._fact_threshold);
                Scope.track(icf);

                PSVM.this._job.update(0L, "Running IPM");
                final IPMInfo ipmInfo = new IPMInfo();
                final Vec alpha = PrimalDualIPM.solve(icf, label, parms.ipmParms(), ipmInfo);
                icf.remove();
                Log.info("IPM finished");

                final Vec svs = new RegulateAlphaTask(parms.c_pos(), parms.c_neg(), parms._sv_threshold)
                        .doAll(Vec.T_NUM, alpha, label)
                        .updateStats(model._output);
                assert svs.length() == model._output._svs_count;
                Scope.track(svs);

                final Frame alphaFrame = new Frame(Key.make(model._key + "_alpha"));
                alphaFrame.add("alpha", alpha);
                DKV.put(alphaFrame);
                model._output._alpha_key = alphaFrame._key;

                final int numSampled = (int) Math.min(svs.length(), (long) SV_SAMPLE_SIZE);
                final DataInfo.Row[] sample =
                        new CollectSupportVecSamplesTask(dinfo, svs, numSampled).doAll(dinfo._adaptedFrame).getSelected();
                model._output._rho =
                        new CalculateRhoTask(dinfo, sample, alpha, parms.kernel()).doAll(dinfo._adaptedFrame).getRho();

                final long estimatedSize = estimateSupportVectorsSize(sample, svs.length());
                final boolean tooBig = estimatedSize >= Integer.MAX_VALUE;
                if (tooBig) {
                    Log.err("Estimated model size (" + estimatedSize + "B) exceeds limits of DKV. Support vectors will not be stored.");
                    model.addWarning("Model too big (size " + estimatedSize + "B) exceeds maximum model size. Support vectors will not be stored as a part of the model. You can still inspect what vectors were chosen and what are their alpha coefficients (see Frame alpha in model output).");
                    model._output._compressed_svs = new byte[0];
                } else {
                    final Frame withAlpha = new Frame(dinfo._adaptedFrame);
                    withAlpha.add("__alpha", alpha);
                    model._output._compressed_svs = new CompressVectorsTask(dinfo).doAll(withAlpha)._csvs;
                    // When every support vector was sampled, the estimate is the exact size.
                    assert svs.length() > numSampled || model._output._compressed_svs.length == estimatedSize;
                }
                Log.info("Total #support vectors: " + model._output._svs_count + " (size in memory " + estimatedSize + "B)");
                model._output._model_summary = PSVM.createModelSummaryTable(model._output, ipmInfo);
                model.update(PSVM.this._job);

                if (!tooBig) {
                    scoreModel(model);
                }
                // The alpha vec is referenced by the alpha frame kept in the model output; don't delete it.
                Scope.untrack(alpha._key);
            } finally {
                if (model != null) {
                    model.unlock(PSVM.this._job);
                }
            }
        }

        /**
         * Computes training metrics (unless disabled) and validation metrics (when a
         * validation frame is set), and stores them on the model output.
         */
        private void scoreModel(PSVMModel model) {
            if (PSVM.this._parms._disable_training_metrics) {
                final String msg = "Not creating training metrics: scoring disabled (use disable_training_metrics = false to override)";
                Log.warn(msg);
                model.addWarning(msg);
            } else {
                PSVM.this._job.update(0L, "Scoring training frame");
                final Frame adaptedTrain = new Frame(PSVM.this.train());
                model.adaptTestForTrain(adaptedTrain, true, true);
                model._output._training_metrics = model.makeModelMetrics(PSVM.this.train(), adaptedTrain, "Training metrics");
            }
            if (PSVM.this.valid() != null) {
                PSVM.this._job.update(0L, "Scoring validation frame");
                final Frame adaptedValid = new Frame(PSVM.this.valid());
                model.adaptTestForTrain(adaptedValid, true, true);
                model._output._validation_metrics = model.makeModelMetrics(PSVM.this.valid(), adaptedValid, "Validation metrics");
            }
        }
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public PSVM(boolean startupOnce) {
        super(new PSVMModel.PSVMParameters(), startupOnce);
    }

    public PSVM(PSVMModel.PSVMParameters parms) {
        super(parms);
        init(false);
    }

    /** Delegates to {@link ModelBuilder#init(boolean)}; PSVM adds no extra checks here. */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
    }

    /**
     * Accepts either a categorical response with exactly 2 levels, or an integer response
     * with min -1, max +1 and no zeros (so only -1/+1 values). Any other response is
     * reported through {@code error("_response", ...)}.
     */
    @Override
    public void checkDistributions() {
        if (_response.isCategorical()) {
            if (_response.cardinality() != 2) {
                error("_response", "Expected a binary categorical response, but instead got response with " + _response.cardinality() + " categories.");
            }
            return;
        }
        final boolean onlyPlusMinusOne = _response.min() == -1.0d
                && _response.max() == 1.0d
                && _response.isInt()
                && _response.nzCnt() == _response.length();
        if (!onlyPlusMinusOne) {
            error("_response", "Non-categorical response provided, please make sure the response is either binary categorical response or uses only values -1/+1 in case of numerical response.");
        }
    }

    @Override
    protected boolean computePriorClassDistribution() {
        return false;
    }

    @Override
    protected int init_getNClass() {
        return 2;
    }

    @Override
    protected ModelBuilder<PSVMModel, PSVMModel.PSVMParameters, PSVMModel.PSVMModelOutput>.Driver trainModelImpl() {
        return new SVMDriver();
    }

    /**
     * Returns the rank requested for the ICF matrix.
     * A ratio of -1 selects {@code (int) sqrt(numRows)}; otherwise the result is
     * {@code (int) (ratio * numRows)}.
     */
    public int getRankICF(double rankRatio, long numRows) {
        return rankRatio == -1.0d ? (int) Math.sqrt(numRows) : (int) (rankRatio * numRows);
    }

    /**
     * Wraps a dense row and its alpha in a {@link SupportVector}.
     *
     * @throws UnsupportedOperationException if the row is sparse
     */
    public static SupportVector toSupportVector(double alpha, DataInfo.Row row) {
        if (row.isSparse()) {
            throw new UnsupportedOperationException("Sparse rows are not yet supported");
        }
        return new SupportVector().fill(alpha, row.numVals, row.binIds);
    }

    /**
     * Estimates the in-memory size of all support vectors from a sample of them.
     *
     * @param sample   sampled support-vector rows
     * @param totalSvs total number of support vectors
     * @return the summed estimated size of the sample. When the sample is smaller than
     *         {@code totalSvs}, the sum is scaled up linearly by the actual sample size.
     */
    private static long estimateSupportVectorsSize(DataInfo.Row[] sample, long totalSvs) {
        long size = 0;
        for (DataInfo.Row row : sample) {
            size += toSupportVector(0.0d, row).estimateSize();
        }
        // Divide by the number of rows actually sampled. Per-node rounding can make it smaller
        // than the requested sample size.
        if (sample.length > 0 && totalSvs > sample.length) {
            size = (size * totalSvs) / sample.length;
        }
        return size;
    }

    /** Builds the one-row "Model Summary" table from the model output and the final IPM state. */
    public static TwoDimTable createModelSummaryTable(PSVMModel.PSVMModelOutput output, IPMInfo ipmInfo) {
        // {header, type, format} for each column, in display order.
        final String[][] columns = {
                {"Number of Support Vectors", "long", "%d"},
                {"Number of Bounded Support Vectors", "long", "%d"},
                {"Raw Model Size in Bytes", "long", "%d"},
                {"rho", "double", "%.5f"},
                {"Number of Iterations", "long", "%d"},
                {"Surrogate Gap", "double", "%.5f"},
                {"Primal Residual", "double", "%.5f"},
                {"Dual Residual", "double", "%.5f"},
        };
        final Object[] values = {
                Long.valueOf(output._svs_count),
                Long.valueOf(output._bsv_count),
                Integer.valueOf(output._compressed_svs != null ? output._compressed_svs.length : -1),
                Double.valueOf(output._rho),
                Integer.valueOf(ipmInfo._iter),
                Double.valueOf(ipmInfo._sgap),
                Double.valueOf(ipmInfo._resp),
                Double.valueOf(ipmInfo._resd),
        };
        assert values.length == columns.length;

        final String[] colHeaders = new String[columns.length];
        final String[] colTypes = new String[columns.length];
        final String[] colFormats = new String[columns.length];
        for (int c = 0; c < columns.length; c++) {
            colHeaders[c] = columns[c][0];
            colTypes[c] = columns[c][1];
            colFormats[c] = columns[c][2];
        }
        final TwoDimTable table = new TwoDimTable("Model Summary", (String) null, new String[1],
                colHeaders, colTypes, colFormats, "");
        for (int c = 0; c < values.length; c++) {
            table.set(0, c, values[c]);
        }
        return table;
    }
}
