package hex.coxph;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import hex.StringPair;
import hex.coxph.Storage;
import hex.schemas.CoxPHModelV3;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.Job;
import water.Key;
import water.MRTask;
import water.api.schemas3.ModelSchemaV3;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.ast.prims.mungers.AstGroup;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedInt;

/* loaded from: input_file:hex/coxph/CoxPHModel.class */
public class CoxPHModel extends Model<CoxPHModel, CoxPHParameters, CoxPHOutput> {

    /**
     * Training output of a CoxPH model.
     *
     * <p>Holds the fitted coefficients and their statistics, per-stratum predictor means,
     * cumulative-hazard related arrays ({@code _cumhaz_0}, {@code _var_cumhaz_*}) and the metadata
     * used when adapting scoring frames (strata map, strata-only columns, interaction spec).
     * The constructor only sets the strata/formula/interaction metadata; the remaining fields are
     * filled in elsewhere (presumably by the CoxPH builder).
     */
    public static class CoxPHOutput extends Model.Output {
        Model.InteractionSpec _interactionSpec;
        DataInfo data_info;
        IcedHashMap<AstGroup.G, IcedInt> _strataMap;
        /** Stratify-by columns that were not part of the model frame; they lead {@code _names}. */
        String[] _strataOnlyCols;
        public String[] _coef_names;
        public double[] _coef;
        public double[] _exp_coef;
        public double[] _exp_neg_coef;
        public double[] _se_coef;
        public double[] _z_coef;
        double[][] _var_coef;
        double _null_loglik;
        double _loglik;
        double _loglik_test;
        double _wald_test;
        double _score_test;
        double _rsq;
        double _maxrsq;
        double _lre;
        int _iter;
        double[][] _x_mean_cat;
        double[][] _x_mean_num;
        double[] _mean_offset;
        String[] _offset_names;
        long _n;
        long _n_missing;
        long _total_event;
        double[] _time;
        double[] _n_risk;
        double[] _n_event;
        double[] _n_censor;
        double[] _cumhaz_0;
        double[] _var_cumhaz_1;
        FrameMatrix _var_cumhaz_2_matrix;
        Key<Frame> _var_cumhaz_2;
        CoxPHParameters.CoxPHTies _ties;
        String _formula;

        /**
         * Appends the interaction columns described by the enclosing output's interaction spec,
         * using the encoding options recorded in its {@code data_info}.
         */
        private class CoxPHInteractionBuilder implements Model.InteractionBuilder {
            private CoxPHInteractionBuilder() {
            }

            public Frame makeInteractions(Frame frame) {
                final DataInfo dinfo = CoxPHOutput.this.data_info;
                final boolean standardize = dinfo._predictor_transform == DataInfo.TransformType.STANDARDIZE;
                frame.add(Model.makeInteractions(
                        frame,
                        false,
                        CoxPHOutput.this._interactionSpec.makeInteractionPairs(frame),
                        dinfo._useAllFactorLevels,
                        dinfo._skipMissing,
                        standardize));
                return frame;
            }
        }

        /**
         * @param coxPH       the builder whose parameters describe the model
         * @param frame       the frame the model is trained on
         * @param frame2      frame supplying stratify-by columns missing from {@code frame}; also
         *                    used to render the model formula
         * @param icedHashMap strata map recorded on the output
         */
        public CoxPHOutput(CoxPH coxPH, Frame frame, Frame frame2, IcedHashMap<AstGroup.G, IcedInt> icedHashMap) {
            super(coxPH, fullFrame(coxPH, frame, frame2));
            final CoxPHParameters parms = (CoxPHParameters) coxPH._parms;
            // fullFrame() prepends the strata-only columns, so (relying on the superclass keeping
            // that column order in _names) they form the leading prefix of _names.
            final int strataOnlyCount = this._names.length - frame._names.length;
            this._strataOnlyCols = Arrays.copyOf(this._names, strataOnlyCount);
            this._ties = parms._ties;
            this._formula = parms.toFormula(frame2);
            this._interactionSpec = parms.interactionSpec();
            this._strataMap = icedHashMap;
        }

        /**
         * Returns {@code base} unchanged for unstratified models; otherwise returns a new frame made of
         * the stratify-by columns absent from {@code base} (taken from {@code source}) followed by all
         * columns of {@code base}.
         */
        private static Frame fullFrame(CoxPH coxPH, Frame base, Frame source) {
            final CoxPHParameters parms = (CoxPHParameters) coxPH._parms;
            if (parms.isStratified()) {
                Frame full = new Frame(new Vec[0]);
                for (String col : parms._stratify_by) {
                    if (base.vec(col) == null) {
                        full.add(col, source.vec(col));
                    }
                }
                full.add(base);
                return full;
            }
            return base;
        }

        public ModelCategory getModelCategory() {
            return ModelCategory.CoxPH;
        }

        /** Returns an interaction builder when an interaction spec is present, otherwise {@code null}. */
        public Model.InteractionBuilder interactionBuilder() {
            return this._interactionSpec == null ? null : new CoxPHInteractionBuilder();
        }
    }

    /**
     * User-facing parameters of a CoxPH model.
     *
     * <p>Besides the standard model parameters, a CoxPH model is described by an optional
     * {@code _start_column}, a {@code _stop_column}, the event column ({@code _response_column}),
     * optional stratify-by columns and optional interaction terms; see {@link #toFormula(Frame)}
     * for how they combine into an R-style {@code Surv(...) ~ ...} formula.
     */
    public static class CoxPHParameters extends Model.Parameters {
        public String _start_column;
        public String _stop_column;
        public String[] _stratify_by;
        public boolean _use_all_factor_levels;
        public String[] _interactions_only;
        /** Name of the synthetic column holding the stratum index of each row. */
        final String _strata_column = "__strata";
        /** How tied event times are handled. */
        public CoxPHTies _ties = CoxPHTies.efron;
        public double _init = 0.0d;
        public double _lre_min = 9.0d;
        public int _max_iterations = 20;
        public String[] _interactions = null;
        public StringPair[] _interaction_pairs = null;
        public boolean _calc_cumhaz = true;

        /** Supported methods for handling tied event times. */
        public enum CoxPHTies {
            efron,
            breslow
        }

        public String algoName() {
            return "CoxPH";
        }

        public String fullName() {
            return "Cox Proportional Hazards";
        }

        public String javaName() {
            return CoxPHModel.class.getName();
        }

        public long progressUnits() {
            // Widen before arithmetic so an extreme _max_iterations cannot overflow int math.
            return 2L * ((long) this._max_iterations + 1) + 1;
        }

        /**
         * Returns the names of the non-predictor "response-like" columns: the optional start column,
         * the synthetic strata column (stratified models only), the stop column and the event column.
         */
        String[] responseCols() {
            String[] cols = this._start_column != null ? new String[]{this._start_column} : new String[0];
            if (isStratified()) {
                // The strata index column is appended here (not the start column a second time),
                // so callers counting these columns in a frame see it.
                cols = (String[]) ArrayUtils.append(cols, new String[]{this._strata_column});
            }
            return (String[]) ArrayUtils.append(cols, new String[]{this._stop_column, this._response_column});
        }

        /** Returns the training frame's start-column vector ({@code null} if not found). */
        public Vec startVec() {
            return train().vec(this._start_column);
        }

        /** Returns the training frame's stop-column vector ({@code null} if not found). */
        public Vec stopVec() {
            return train().vec(this._stop_column);
        }

        /**
         * Builds the interaction specification. The interactions-only set passed to
         * {@code InteractionSpec.create} is the sorted union of {@code _interactions_only} and
         * {@code _stratify_by} (or whichever of the two is non-null).
         */
        public Model.InteractionSpec interactionSpec() {
            String[] interactionsOnly;
            if (this._interactions_only == null || this._stratify_by == null) {
                interactionsOnly = this._interactions_only != null ? this._interactions_only : this._stratify_by;
            } else {
                String[] sortedOnly = this._interactions_only.clone();
                Arrays.sort(sortedOnly);
                String[] sortedStrata = this._stratify_by.clone();
                Arrays.sort(sortedStrata);
                interactionsOnly = ArrayUtils.union(sortedOnly, sortedStrata, true);
            }
            return Model.InteractionSpec.create(this._interactions, this._interaction_pairs, interactionsOnly, this._stratify_by);
        }

        /** Returns {@code true} when at least one stratify-by column is configured. */
        public boolean isStratified() {
            return this._stratify_by != null && this._stratify_by.length > 0;
        }

        /**
         * Renders the model as an R-style formula, e.g.
         * {@code Surv(start, stop, event) ~ x1 + x2 + offset(o) + a:strata(s) + strata(s)}.
         *
         * <p>Plain predictor terms are the columns of {@code frame} that are not the offset,
         * stratify-by, interactions-only, time/event, strata, weights or ignored columns.
         */
        String toFormula(Frame frame) {
            StringBuilder sb = new StringBuilder("Surv(");
            if (this._start_column != null) {
                sb.append(this._start_column).append(", ");
            }
            sb.append(this._stop_column).append(", ").append(this._response_column).append(") ~ ");

            final Set<String> strataCols = toSet(this._stratify_by);
            final Set<String> interactionOnlyCols = toSet(this._interactions_only);
            final Set<String> specialCols = new HashSet<String>();
            if (this._start_column != null) {
                specialCols.add(this._start_column);
            }
            if (this._stop_column != null) {
                specialCols.add(this._stop_column);
            }
            specialCols.add(this._response_column);
            specialCols.add(this._strata_column);
            if (this._weights_column != null) {
                specialCols.add(this._weights_column);
            }
            if (this._ignored_columns != null) {
                specialCols.addAll(Arrays.asList(this._ignored_columns));
            }

            String sep = "";
            for (String name : frame._names) {
                boolean isOffset = this._offset_column != null && this._offset_column.equals(name);
                if (isOffset || strataCols.contains(name) || interactionOnlyCols.contains(name) || specialCols.contains(name)) {
                    continue;
                }
                sb.append(sep).append(name);
                sep = " + ";
            }
            if (this._offset_column != null) {
                sb.append(sep).append("offset(").append(this._offset_column).append(")");
                // Any following term must be separated from the offset term.
                sep = " + ";
            }

            final Model.InteractionSpec spec = interactionSpec();
            if (spec != null) {
                for (Model.InteractionPair pair : spec.makeInteractionPairs(frame)) {
                    sb.append(sep)
                      .append(formulaTerm(frame._names[pair.getV1()], strataCols))
                      .append(":")
                      .append(formulaTerm(frame._names[pair.getV2()], strataCols));
                    sep = " + ";
                }
            }

            if (this._stratify_by != null) {
                // Only add a strata(...) term if it is not already present (e.g. inside an interaction).
                final String soFar = sb.toString();
                for (String col : this._stratify_by) {
                    String term = "strata(" + col + ")";
                    if (!soFar.contains(term)) {
                        sb.append(sep).append(term);
                        sep = " + ";
                    }
                }
            }
            return sb.toString();
        }

        /** Returns the given names as a set, or an empty set when {@code cols} is {@code null}. */
        private static Set<String> toSet(String[] cols) {
            return cols != null ? new HashSet<String>(Arrays.asList(cols)) : Collections.<String>emptySet();
        }

        /** Renders one interaction operand, wrapping stratify-by columns in {@code strata(...)}. */
        private static String formulaTerm(String col, Set<String> strataCols) {
            return strataCols.contains(col) ? "strata(" + col + ")" : col;
        }
    }

    /* loaded from: input_file:hex/coxph/CoxPHModel$CoxPHScore.class */
    private static class CoxPHScore extends MRTask<CoxPHScore> {
        private DataInfo _dinfo;
        private double[] _coef;
        private double[] _lpBase;
        private int _numStart;
        private boolean _hasStrata;

        private CoxPHScore(DataInfo dataInfo, CoxPHOutput coxPHOutput, boolean z) {
            int length = coxPHOutput._x_mean_cat.length;
            this._dinfo = dataInfo;
            this._hasStrata = z;
            this._coef = coxPHOutput._coef;
            this._numStart = coxPHOutput._x_mean_cat[0].length;
            this._lpBase = new double[length];
            for (int i = 0; i < length; i++) {
                for (int i2 = 0; i2 < coxPHOutput._x_mean_cat[i].length; i2++) {
                    double[] dArr = this._lpBase;
                    int i3 = i;
                    dArr[i3] = dArr[i3] + (coxPHOutput._x_mean_cat[i][i2] * this._coef[i2]);
                }
                for (int i4 = 0; i4 < coxPHOutput._x_mean_num[i].length; i4++) {
                    double[] dArr2 = this._lpBase;
                    int i5 = i;
                    dArr2[i5] = dArr2[i5] + (coxPHOutput._x_mean_num[i][i4] * this._coef[i4 + this._numStart]);
                }
            }
        }

        public void map(Chunk[] chunkArr, NewChunk newChunk) {
            DataInfo.Row newDenseRow = this._dinfo.newDenseRow();
            for (int i = 0; i < chunkArr[0]._len; i++) {
                this._dinfo.extractDenseRow(chunkArr, i, newDenseRow);
                if (newDenseRow.predictors_bad) {
                    newChunk.addNA();
                } else {
                    double atd = this._hasStrata ? chunkArr[this._dinfo.responseChunkId(0)].atd(i) : 0.0d;
                    if (Double.isNaN(atd)) {
                        newChunk.addNA();
                    } else {
                        newChunk.addNum(newDenseRow.innerProduct(this._coef) - this._lpBase[(int) atd]);
                    }
                }
            }
        }
    }

    /**
     * Dense row matrix tied to a frame key.
     *
     * <p>The custom serialization hooks write only the frame key. When read back, if nothing is
     * stored in the DKV under that key, {@code toFrame} is called with the key.
     */
    public static class FrameMatrix extends Storage.DenseRowMatrix {
        Key<Frame> _frame_key;

        /** Creates a matrix with the given dimensions (presumably rows, cols) bound to {@code frameKey}. */
        public FrameMatrix(Key<Frame> frameKey, int rows, int cols) {
            super(rows, cols);
            this._frame_key = frameKey;
        }

        public final AutoBuffer write_impl(AutoBuffer ab) {
            // Only the key goes onto the wire in this hook.
            Key.write_impl(this._frame_key, ab);
            return ab;
        }

        public final FrameMatrix read_impl(AutoBuffer ab) {
            this._frame_key = Key.read_impl((Key) null, ab);
            final boolean absentFromDkv = DKV.getGet(this._frame_key) == null;
            if (absentFromDkv) {
                toFrame(this._frame_key);
            }
            return this;
        }
    }

    /**
     * Returns a regression metric builder; the domain argument is not used.
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        ModelMetrics.MetricBuilder builder = new ModelMetricsRegression.MetricBuilderRegression();
        return builder;
    }

    /** Returns a fresh REST schema instance for this model type. */
    public ModelSchemaV3 schema() {
        ModelSchemaV3 modelSchema = new CoxPHModelV3();
        return modelSchema;
    }

    /**
     * Creates a CoxPH model.
     *
     * @param selfKey key the model is stored under
     * @param parms   model parameters
     * @param output  training output
     */
    public CoxPHModel(Key selfKey, CoxPHParameters parms, CoxPHOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores the adapted frame, producing a single numeric column named {@code "lp"}.
     *
     * @param frame    original test frame (unused here)
     * @param frame2   test frame adapted to the training layout
     * @param str      destination key name of the prediction frame
     * @param job      job handle (unused here)
     * @param z        compute-metrics flag (unused here)
     * @param cFuncRef custom metric function (unused here)
     */
    protected Frame predictScoreImpl(Frame frame, Frame frame2, String str, Job job, boolean z, CFuncRef cFuncRef) {
        final CoxPHParameters parms = (CoxPHParameters) this._parms;
        final CoxPHOutput output = (CoxPHOutput) this._output;

        // Count the response-like columns still present after adaptation; passed to scoringInfo
        // (presumably as its response-column count).
        int presentResponses = 0;
        for (String col : parms.responseCols()) {
            presentResponses += frame2.find(col) != -1 ? 1 : 0;
        }

        final DataInfo scoringInfo = output.data_info.scoringInfo(output._names, frame2, presentResponses, false);
        final CoxPHScore scored = new CoxPHScore(scoringInfo, output, parms.isStratified())
                .doAll(Vec.T_NUM, scoringInfo._adaptedFrame);
        return scored.outputFrame(Key.make(str), new String[]{"lp"}, (String[][]) null);
    }

    /**
     * Adapts a test frame to the training layout, rebuilding the strata column when needed.
     *
     * <p>For a stratified model whose test frame lacks the strata column:
     * <ol>
     *   <li>a NaN placeholder strata vector is added before the generic adaptation runs;</li>
     *   <li>after {@code super.adaptTestForTrain}, the placeholder is replaced by the real stratum
     *       index computed from the stratify-by columns and the training strata map;</li>
     *   <li>stratify-by columns that were not model columns ({@code _strataOnlyCols}) are removed.</li>
     * </ol>
     * Both created vectors are registered in {@code _toDelete}.
     *
     * @param test           frame to adapt (modified in place)
     * @param expensive      forwarded to the superclass
     * @param computeMetrics forwarded to the superclass
     * @return the messages returned by the superclass adaptation
     */
    public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
        final CoxPHParameters parms = (CoxPHParameters) this._parms;
        final CoxPHOutput output = (CoxPHOutput) this._output;
        final boolean createStrataVec = parms.isStratified() && test.vec(parms._strata_column) == null;

        if (createStrataVec) {
            // Placeholder, presumably so the generic adaptation finds the strata column the model expects.
            Vec placeholder = test.anyVec().makeCon(Double.NaN);
            this._toDelete.put(placeholder._key, "adapted missing strata vector");
            test.add(parms._strata_column, placeholder);
        }

        String[] msgs = super.adaptTestForTrain(test, expensive, computeMetrics);

        if (createStrataVec) {
            // Stratify-by columns are now adapted; compute each row's actual stratum index.
            Vec strataVec = CoxPH.StrataTask.makeStrataVec(test, parms._stratify_by, output._strataMap);
            this._toDelete.put(strataVec._key, "adapted missing strata vector");
            test.replace(test.find(parms._strata_column), strataVec);
            if (output._strataOnlyCols != null) {
                test.remove(output._strataOnlyCols);
            }
        }
        return msgs;
    }

    /**
     * Row-wise scoring is not supported for CoxPH; scoring goes through {@code predictScoreImpl}.
     *
     * @throws UnsupportedOperationException always
     */
    public double[] score0(double[] data, double[] preds) {
        final String msg = "CoxPHModel.score0 should never be called";
        throw new UnsupportedOperationException(msg);
    }

    /**
     * Removes the {@code _var_cumhaz_2} frame (when its key resolves to a frame), then delegates to
     * the superclass removal.
     *
     * @return the same {@code futures} instance that was passed in
     */
    protected Futures remove_impl(Futures futures, boolean z) {
        final Key<Frame> varCumhazKey = ((CoxPHOutput) this._output)._var_cumhaz_2;
        if (varCumhazKey != null) {
            final Frame varCumhaz = (Frame) varCumhazKey.get();
            if (varCumhaz != null) {
                varCumhaz.remove(futures);
            }
        }
        super.remove_impl(futures, z);
        return futures;
    }
}
