package hex.tree.gbm;

import hex.Model;
import hex.ModelBuilder;
import hex.genmodel.GenModel;
import hex.schemas.GBMV3;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.SharedTree;
import hex.tree.TreeJCodeGen;
import hex.tree.gbm.GBMModel;
import water.AutoBuffer;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.Timer;

/* loaded from: input_file:hex/tree/gbm/GBM.class */
public class GBM extends SharedTree<GBMModel, GBMModel.GBMParameters, GBMModel.GBMOutput> {

    /**
     * Decided node for GBM trees. It chooses the split with the lowest squared error
     * among the column histograms, and creates GBM-flavoured undecided children.
     */
    static class GBMDecidedNode extends DTree.DecidedNode {
        GBMDecidedNode(DTree.UndecidedNode undecidedNode, DHistogram[] dHistogramArr) {
            super(undecidedNode, dHistogramArr);
        }

        /** Creates a {@link GBMUndecidedNode} child of this node over the given histograms. */
        @Override
        public DTree.UndecidedNode makeUndecidedNode(DHistogram[] dHistogramArr) {
            return new GBMUndecidedNode(this._tree, this._nid, dHistogramArr);
        }

        /**
         * Returns the candidate split with the smallest {@code se()} over all usable columns.
         *
         * <p>A column is usable when its histogram is non-null and has more than one bin.
         * {@code scoreMSE} may still decline to produce a split by returning {@code null}.
         * The scan stops as soon as a candidate with {@code se() <= 0} has been seen. If
         * nothing qualifies, or {@code dHistogramArr} is {@code null}, a sentinel split is
         * returned: column -1 with {@code Double.MAX_VALUE} errors.
         */
        @Override
        public DTree.Split bestCol(DTree.UndecidedNode undecidedNode, DHistogram[] dHistogramArr) {
            DTree.Split best = new DTree.Split(-1, -1, null, (byte) 0,
                    Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE, 0L, 0L, 0.0d, 0.0d);
            if (dHistogramArr == null) {
                return best;
            }
            boolean zeroErrorSeen = false;
            int col = 0;
            while (col < dHistogramArr.length && !zeroErrorSeen) {
                DHistogram hist = dHistogramArr[col];
                if (hist != null && hist.nbins() > 1) {
                    DTree.Split candidate = hist.scoreMSE(col, this._tree._min_rows);
                    if (candidate != null) {
                        if (candidate.se() < best.se()) {
                            best = candidate;
                        }
                        zeroErrorSeen = candidate.se() <= 0.0d;
                    }
                }
                col++;
            }
            return best;
        }
    }

    /* loaded from: input_file:hex/tree/gbm/GBM$GBMDriver.class */
    private class GBMDriver extends SharedTree<GBMModel, GBMModel.GBMParameters, GBMModel.GBMOutput>.Driver {
        // Assertion switch for GBMDriver (marked synthetic by the decompiler). It is set in the
        // static initializer at the end of this class from GBM.class.desiredAssertionStatus().
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Converts the raw tree columns ({@code chk_tree}) into predictions stored in the
         * work columns ({@code chk_work}).
         *
         * <ul>
         *   <li>bernoulli: {@code work[0] = 1 / (1 + exp(tree[0]))}</li>
         *   <li>single output ({@code _nclass <= 1}): {@code work[0] = (float) tree[0]}</li>
         *   <li>otherwise: {@code score1} fills {@code preds[1..nclass]} and returns a
         *       normaliser. Each class gets {@code preds[k+1] / normaliser}. When the
         *       normaliser is infinite, classes whose own value is infinite get 1 and the
         *       others get 0.</li>
         * </ul>
         */
        class ComputeProb extends MRTask<ComputeProb> {
            ComputeProb() {
            }

            public void map(Chunk[] chunkArr) {
                final int rows = GBM.this.chk_resp(chunkArr)._len;
                if (GBM.this._parms._distribution == GBMModel.GBMParameters.Family.bernoulli) {
                    Chunk tree = GBM.this.chk_tree(chunkArr, 0);
                    Chunk work = GBM.this.chk_work(chunkArr, 0);
                    for (int row = 0; row < rows; row++) {
                        double f = tree.atd(row);
                        work.set(row, 1.0d / (1.0d + Math.exp(f)));
                    }
                } else if (GBM.this._nclass <= 1) {
                    Chunk tree = GBM.this.chk_tree(chunkArr, 0);
                    Chunk work = GBM.this.chk_work(chunkArr, 0);
                    for (int row = 0; row < rows; row++) {
                        work.set(row, (float) tree.atd(row));
                    }
                } else {
                    final int nclass = GBM.this._nclass;
                    // preds[0] is scratch space for score1; classes live at 1..nclass.
                    double[] preds = new double[nclass + 1];
                    for (int row = 0; row < rows; row++) {
                        double norm = GBM.this.score1(chunkArr, preds, row);
                        boolean overflowed = Double.isInfinite(norm);
                        for (int k = 0; k < nclass; k++) {
                            Chunk work = GBM.this.chk_work(chunkArr, k);
                            if (overflowed) {
                                work.set(row, Double.isInfinite(preds[k + 1]) ? 1.0f : 0.0f);
                            } else {
                                work.set(row, (float) (preds[k + 1] / norm));
                            }
                        }
                    }
                }
            }
        }

        /**
         * Rewrites the work columns from predictions into residuals, which are what the
         * next trees are fit to. {@code buildModel} runs this right after
         * {@link ComputeProb}.
         *
         * <ul>
         *   <li>bernoulli: {@code work[0] = (y - 1) + work[0]}; rows with an NA response
         *       are left unchanged.</li>
         *   <li>single output ({@code _nclass <= 1}): {@code work[0] = y - work[0]}</li>
         *   <li>otherwise: for every class with a non-zero distribution entry,
         *       {@code work[k] = [y == k] - work[k]}; rows with an NA response are left
         *       unchanged.</li>
         * </ul>
         * All arithmetic is done in {@code float}, as in the stored columns.
         */
        class ComputeRes extends MRTask<ComputeRes> {
            ComputeRes() {
            }

            public void map(Chunk[] chunkArr) {
                Chunk resp = GBM.this.chk_resp(chunkArr);
                final int rows = resp._len;
                if (GBM.this._parms._distribution == GBMModel.GBMParameters.Family.bernoulli) {
                    Chunk work = GBM.this.chk_work(chunkArr, 0);
                    for (int row = 0; row < rows; row++) {
                        if (resp.isNA(row)) {
                            continue;
                        }
                        int y = (int) resp.at8(row);
                        float current = (float) work.atd(row);
                        work.set(row, (y - 1.0f) + current);
                    }
                    return;
                }
                if (GBM.this._nclass <= 1) {
                    Chunk work = GBM.this.chk_work(chunkArr, 0);
                    for (int row = 0; row < rows; row++) {
                        work.set(row, (float) (resp.atd(row) - work.atd(row)));
                    }
                    return;
                }
                final int nclass = GBM.this._nclass;
                for (int row = 0; row < rows; row++) {
                    if (resp.isNA(row)) {
                        continue;
                    }
                    int y = (int) resp.at8(row);
                    for (int k = 0; k < nclass; k++) {
                        // Classes with an empty distribution entry get no tree, so skip them.
                        if (((GBMModel) GBM.this._model)._output._distribution[k] == 0) {
                            continue;
                        }
                        Chunk work = GBM.this.chk_work(chunkArr, k);
                        float target = (y == k) ? 1.0f : 0.0f;
                        work.set(row, target - (float) work.atd(row));
                    }
                }
            }
        }

        /**
         * Map/reduce pass that routes every row to its final leaf. For each class tree and
         * each leaf created this round, it accumulates the two sums used to set that leaf's
         * prediction:
         * <ul>
         *   <li>{@code _rss}: the sum of the rows' work-column values. When run from
         *       {@code buildModel}, these are the residuals written by ComputeRes.</li>
         *   <li>{@code _gss}: a denominator term, computed as follows:
         *     <ul>
         *       <li>bernoulli: {@code p * (1 - p)} with {@code p = response - work}</li>
         *       <li>{@code _nclass > 1}: {@code |work| * (1 - |work|)}</li>
         *       <li>otherwise: 1 per row</li>
         *     </ul>
         *   </li>
         * </ul>
         * Leaf {@code nid} of tree {@code k} is stored at index {@code nid - _leafs[k]}.
         * As a side effect, each routed row's node-id column is overwritten with its leaf id.
         * Rows with a negative node id are skipped. Trees whose root is already a leaf get
         * zero-filled arrays.
         */
        public class GammaPass extends MRTask<GammaPass> {
            final DTree[] _trees;
            final int[] _leafs;
            final boolean _isBernoulli;
            double[][] _rss;
            double[][] _gss;

            GammaPass(DTree[] dTreeArr, int[] iArr, boolean z) {
                this._leafs = iArr;
                this._trees = dTreeArr;
                this._isBernoulli = z;
            }

            public void map(Chunk[] chunkArr) {
                final int nclass = GBM.this._nclass;
                this._gss = new double[nclass][];
                this._rss = new double[nclass][];
                Chunk resp = GBM.this.chk_resp(chunkArr);
                for (int k = 0; k < nclass; k++) {
                    DTree tree = this._trees[k];
                    if (tree == null) {
                        continue;
                    }
                    int firstLeaf = this._leafs[k];
                    double[] gss = new double[tree._len - firstLeaf];
                    double[] rss = new double[tree._len - firstLeaf];
                    this._gss[k] = gss;
                    this._rss[k] = rss;
                    Chunk nids = GBM.this.chk_nids(chunkArr, k);
                    Chunk work = GBM.this.chk_work(chunkArr, k);
                    if (tree.root() instanceof DTree.LeafNode) {
                        continue;
                    }
                    for (int row = 0; row < nids._len; row++) {
                        int nid = (int) nids.at8(row);
                        if (nid < 0) {
                            continue;
                        }
                        // Rows may still point at an undecided node; step up to its parent.
                        if (tree.node(nid) instanceof DTree.UndecidedNode) {
                            nid = tree.node(nid)._pid;
                        }
                        DTree.DecidedNode dn = tree.decided(nid);
                        // A decided node without a split column defers to its parent.
                        if (dn._split._col == -1) {
                            dn = tree.decided(dn._pid);
                        }
                        int leafnid = dn.ns(chunkArr, row);
                        assert firstLeaf <= leafnid && leafnid < tree._len
                                : "leaf: " + firstLeaf + " leafnid: " + leafnid + " tree._len: " + tree._len + "\ndn: " + dn;
                        assert tree.node(leafnid) instanceof DTree.LeafNode;
                        nids.set(row, leafnid);
                        assert !work.isNA(row);
                        double res = work.atd(row);
                        int slot = leafnid - firstLeaf;
                        if (this._isBernoulli) {
                            double p = resp.atd(row) - res;
                            gss[slot] += p * (1.0d - p);
                        } else if (nclass > 1) {
                            double a = Math.abs(res);
                            gss[slot] += a * (1.0d - a);
                        } else {
                            gss[slot] += 1.0d;
                        }
                        rss[slot] += res;
                    }
                }
            }

            /** Element-wise sums the per-leaf accumulators of another chunk's pass into this one. */
            public void reduce(GammaPass gammaPass) {
                ArrayUtils.add(this._gss, gammaPass._gss);
                ArrayUtils.add(this._rss, gammaPass._rss);
            }
        }

        /** Creates the GBM training driver by delegating to {@code SharedTree.Driver}'s constructor. */
        private GBMDriver() {
            super();
        }

        /**
         * Main GBM training loop.
         *
         * <ol>
         *   <li>If {@code _initialPrediction} is non-zero, write it into tree column 0 of
         *       every row.</li>
         *   <li>When resuming from a checkpoint, rebuild the residual statistics with
         *       {@code ResidualsCollector}.</li>
         *   <li>Run up to {@code _ntrees} iterations. Each iteration:
         *     <ul>
         *       <li>Scores the model and returns early once the score reaches
         *           {@code _r2_stopping}. Scoring is skipped on the first iteration of a
         *           checkpointed run.</li>
         *       <li>Recomputes probabilities and residuals.</li>
         *       <li>Builds the next K trees.</li>
         *       <li>Reports one unit of progress.</li>
         *       <li>Returns if the job is no longer running.</li>
         *     </ul>
         *   </li>
         *   <li>Perform a final scoring pass, but only after all iterations complete.</li>
         * </ol>
         */
        @Override
        protected void buildModel() {
            final double init = GBM.this._initialPrediction;
            if (init != 0.0d) {
                new MRTask() {
                    public void map(Chunk chunk) {
                        for (int row = 0; row < chunk._len; row++) {
                            chunk.set(row, init);
                        }
                    }
                }.doAll(new Vec[]{GBM.this.vec_tree(GBM.this._train, 0)});
            }
            if (GBM.this._parms._checkpoint) {
                Timer rebuildTimer = new Timer();
                new ResidualsCollector(GBM.this._ncols, GBM.this._nclass, ((GBMModel) GBM.this._model)._output._treeKeys).doAll(GBM.this._train);
                Log.info(new Object[]{"Reconstructing tree residuals stats from checkpointed model took " + rebuildTimer});
            }
            for (int tid = 0; tid < GBM.this._parms._ntrees; tid++) {
                // A checkpointed model was already scored; skip the redundant first scoring.
                boolean skipScoring = tid == 0 && GBM.this._parms._checkpoint;
                if (!skipScoring) {
                    if (GBM.this.doScoringAndSaveModel(false, false, false) >= GBM.this._parms._r2_stopping) {
                        return;
                    }
                }
                new ComputeProb().doAll(GBM.this._train);
                new ComputeRes().doAll(GBM.this._train);
                Timer treeTimer = new Timer();
                buildNextKTrees();
                Log.info(new Object[]{(tid + 1) + ". tree was built in " + treeTimer.toString()});
                GBM.this.update(1L);
                if (!GBM.this.isRunning()) {
                    return;
                }
            }
            GBM.this.doScoringAndSaveModel(true, false, false);
        }

        /**
         * Grows the next round of K trees and folds their predictions into the running
         * tree columns. One tree is grown per class that needs one.
         *
         * <ol>
         *   <li>Create a root for every class with a non-zero distribution entry. With
         *       exactly two classes, class 1 gets no tree.</li>
         *   <li>Grow the trees layer by layer. This stops after {@code _max_depth} layers or
         *       when {@code buildLayer} returns {@code null}. If the job stops running, the
         *       method returns without touching the model.</li>
         *   <li>Replace missing or unsplit children of split nodes with
         *       {@link GBMLeafNode}s.</li>
         *   <li>Run a {@link GammaPass}, then set each new leaf's prediction from the
         *       accumulated sums.</li>
         *   <li>Add the leaf predictions to the tree columns, reset the node ids, and append
         *       the trees to the model output.</li>
         * </ol>
         */
        private void buildNextKTrees() {
            final int nclass = GBM.this._nclass;
            final DTree[] ktrees = new DTree[nclass];
            DHistogram[][][] hcs = new DHistogram[nclass][1][GBM.this._ncols];
            int maxBins = Math.max(TreeJCodeGen.MAX_NODES, GBM.this._parms._nbins);
            initKTreeRoots(ktrees, hcs, maxBins);

            int[] leafs = new int[nclass];
            for (int depth = 0; depth < GBM.this._parms._max_depth; depth++) {
                if (!GBM.this.isRunning()) {
                    return;
                }
                hcs = GBM.this.buildLayer(GBM.this._train, maxBins, ktrees, leafs, hcs, false, false);
                if (hcs == null) {
                    break;
                }
            }

            convertToLeaves(ktrees, leafs);
            GammaPass gammaPass = (GammaPass) new GammaPass(ktrees, leafs,
                    GBM.this._parms._distribution == GBMModel.GBMParameters.Family.bernoulli).doAll(GBM.this._train);
            fitLeafGammas(ktrees, leafs, gammaPass);
            applyLeafPredictions(ktrees);
            ((GBMModel) GBM.this._model)._output.addKTrees(ktrees);
        }

        /** Allocates a tree plus a GBM root node for each class that gets a tree this round. */
        private void initKTreeRoots(DTree[] ktrees, DHistogram[][][] hcs, int maxBins) {
            for (int k = 0; k < GBM.this._nclass; k++) {
                boolean emptyClass = ((GBMModel) GBM.this._model)._output._distribution[k] == 0;
                // With exactly two classes only class 0 is modelled.
                boolean secondBinomialClass = k == 1 && GBM.this._nclass == 2;
                if (emptyClass || secondBinomialClass) {
                    continue;
                }
                ktrees[k] = new DTree(GBM.this._train._names, GBM.this._ncols, (char) GBM.this._parms._nbins,
                        (char) GBM.this._nclass, GBM.this._parms._min_rows);
                new GBMUndecidedNode(ktrees[k], -1,
                        DHistogram.initialHist(GBM.this._train, GBM.this._ncols, maxBins, hcs[k][0], false));
            }
        }

        /**
         * Records each tree's current length in {@code leafs[k]}, so that leaves added from
         * here on are indexed from that offset. It then turns every missing, undecided or
         * split-less child of a split node into a {@link GBMLeafNode}. A split-less decided
         * root (node 0) gets a {@code GBMLeafNode(tree, -1, 0)}.
         */
        private void convertToLeaves(DTree[] ktrees, int[] leafs) {
            for (int k = 0; k < GBM.this._nclass; k++) {
                DTree tree = ktrees[k];
                if (tree == null) {
                    continue;
                }
                int len = tree.len();
                leafs[k] = len;
                // Only the pre-existing nodes are visited; new leaves are appended past `len`.
                for (int nid = 0; nid < len; nid++) {
                    if (!(tree.node(nid) instanceof DTree.DecidedNode)) {
                        continue;
                    }
                    DTree.DecidedNode dn = tree.decided(nid);
                    if (dn._split._col != -1) {
                        for (int i = 0; i < dn._nids.length; i++) {
                            int cnid = dn._nids[i];
                            if (cnid == -1
                                    || tree.node(cnid) instanceof DTree.UndecidedNode
                                    || (tree.node(cnid) instanceof DTree.DecidedNode
                                        && ((DTree.DecidedNode) tree.node(cnid))._split.col() == -1)) {
                                dn._nids[i] = new GBMLeafNode(tree, nid).nid();
                            }
                        }
                    } else if (nid == 0) {
                        new GBMLeafNode(tree, -1, 0);
                    }
                }
            }
        }

        /**
         * Sets each new leaf's prediction to {@code learn_rate * m1class * rss / gss}, using
         * the sums accumulated by the gamma pass.
         *
         * <ul>
         *   <li>{@code m1class} is {@code (K-1)/K} for non-bernoulli multi-class problems and
         *       1 otherwise.</li>
         *   <li>Leaves with {@code gss == 0} get {@code signum(rss) * 1e4}.</li>
         *   <li>Multinomial predictions are clamped to {@code [-1e4, 1e4]}.</li>
         * </ul>
         */
        private void fitLeafGammas(DTree[] ktrees, int[] leafs, GammaPass gammaPass) {
            final int nclass = GBM.this._nclass;
            final boolean bernoulli = GBM.this._parms._distribution == GBMModel.GBMParameters.Family.bernoulli;
            // Must be floating-point division. The integer form (K-1)/K is 0 for every K >= 2,
            // which silently zeroed the multinomial step size.
            final double m1class = (nclass <= 1 || bernoulli) ? 1.0d : (double) (nclass - 1) / nclass;
            for (int k = 0; k < nclass; k++) {
                DTree tree = ktrees[k];
                if (tree == null) {
                    continue;
                }
                double[] rss = gammaPass._rss[k];
                double[] gss = gammaPass._gss[k];
                for (int i = 0; i < tree._len - leafs[k]; i++) {
                    float gamma;
                    if (gss[i] == 0.0d) {
                        gamma = (float) (Math.signum(rss[i]) * 10000.0d);
                    } else {
                        gamma = (float) (((GBM.this._parms._learn_rate * m1class) * rss[i]) / gss[i]);
                    }
                    if (GBM.this._parms._distribution == GBMModel.GBMParameters.Family.multinomial) {
                        if (gamma > 10000.0d) {
                            gamma = 10000.0f;
                        } else if (gamma < -10000.0d) {
                            gamma = -10000.0f;
                        }
                    }
                    if (!$assertionsDisabled && (Float.isNaN(gamma) || Float.isInfinite(gamma))) {
                        throw new AssertionError();
                    }
                    ((DTree.LeafNode) tree.node(leafs[k] + i))._pred = gamma;
                }
            }
        }

        /**
         * For every row with a non-negative node id, adds that leaf's prediction to the
         * class's tree column (stored as float) and resets the row's node id to 0.
         */
        private void applyLeafPredictions(final DTree[] ktrees) {
            new MRTask() {
                public void map(Chunk[] chunkArr) {
                    for (int k = 0; k < GBM.this._nclass; k++) {
                        DTree tree = ktrees[k];
                        if (tree == null) {
                            continue;
                        }
                        Chunk nids = GBM.this.chk_nids(chunkArr, k);
                        Chunk ct = GBM.this.chk_tree(chunkArr, k);
                        for (int row = 0; row < nids._len; row++) {
                            int nid = (int) nids.at8(row);
                            if (nid < 0) {
                                continue;
                            }
                            ct.set(row, (float) (ct.atd(row) + ((DTree.LeafNode) tree.node(nid))._pred));
                            nids.set(row, 0L);
                        }
                    }
                }
            }.doAll(GBM.this._train);
        }

        /**
         * Creates a {@link GBMModel} under {@code key} with the given parameters. Its output
         * is a new {@code GBMOutput} bound to this builder. The two doubles are passed through
         * unchanged (NOTE(review): their meaning is defined by SharedTree; confirm).
         */
        @Override
        public GBMModel makeModel(Key key, GBMModel.GBMParameters gBMParameters, double d, double d2) {
            GBMModel.GBMOutput output = new GBMModel.GBMOutput(GBM.this, d, d2);
            return new GBMModel(key, gBMParameters, output);
        }

        // Enables this driver's explicit assertion checks exactly when assertions are
        // enabled for GBM. This mirrors the synthetic field javac emits for `assert` statements.
        static {
            $assertionsDisabled = !GBM.class.desiredAssertionStatus();
        }
    }

    /**
     * Leaf node of a GBM tree. Its compressed form is only its float prediction (4 bytes).
     */
    public static class GBMLeafNode extends DTree.LeafNode {
        GBMLeafNode(DTree dTree, int i) {
            super(dTree, i);
        }

        GBMLeafNode(DTree dTree, int i, int i2) {
            super(dTree, i, i2);
        }

        /** Appends the prediction as a 4-byte float; the prediction must not be NaN. */
        @Override
        protected AutoBuffer compress(AutoBuffer autoBuffer) {
            assert !Double.isNaN(this._pred);
            return autoBuffer.put4f(this._pred);
        }

        /** Size in bytes of the compressed form written by {@link #compress}. */
        @Override
        protected int size() {
            return 4;
        }
    }

    /**
     * Undecided node for GBM trees. {@link #scoreCols} always returns {@code null}.
     * NOTE(review): presumably DTree treats {@code null} as "consider every column"; confirm.
     */
    public static class GBMUndecidedNode extends DTree.UndecidedNode {
        GBMUndecidedNode(DTree dTree, int i, DHistogram[] dHistogramArr) {
            super(dTree, i, dHistogramArr);
        }

        @Override
        public int[] scoreCols(DHistogram[] dHistogramArr) {
            final int[] noColumnSubset = null;
            return noColumnSubset;
        }
    }

    /** Returns the model categories GBM can build: regression, binomial and multinomial. */
    public Model.ModelCategory[] can_build() {
        Model.ModelCategory[] categories = {
            Model.ModelCategory.Regression,
            Model.ModelCategory.Binomial,
            Model.ModelCategory.Multinomial,
        };
        return categories;
    }

    /** Reports this builder's visibility as {@code Stable}. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /**
     * Creates a GBM builder named {@code GBMGrid.MODEL_NAME} for the given parameters and
     * runs the non-expensive validation pass {@code init(false)}.
     */
    public GBM(GBMModel.GBMParameters gBMParameters) {
        super(GBMGrid.MODEL_NAME, gBMParameters);
        this.init(false);
    }

    /**
     * Returns a new {@link GBMV3} schema instance. The decompiler renamed this method from
     * {@code schema()} and merged it with its bridge method.
     */
    public GBMV3 m168schema() {
        GBMV3 schema = new GBMV3();
        return schema;
    }

    /** Starts a job on a new {@link GBMDriver}, with total work equal to {@code _ntrees}. */
    public Job<GBMModel> trainModel() {
        GBMDriver driver = new GBMDriver();
        return start(driver, this._parms._ntrees);
    }

    /** Returns the inherited {@code vresponse()}, or {@code response()} when that is null. */
    public Vec vresponse() {
        Vec inherited = super.vresponse();
        if (inherited != null) {
            return inherited;
        }
        return response();
    }

    /**
     * Validates the GBM parameters and derives data-dependent settings.
     *
     * <p>On the expensive pass ({@code z == true}) this method:
     * <ul>
     *   <li>Reads the response mean {@code m} and seeds {@code _initialPrediction}:
     *     <ul>
     *       <li>regression: {@code m}</li>
     *       <li>2 classes: {@code -0.5 * log(m / (1 - m))}</li>
     *       <li>otherwise: 0</li>
     *     </ul>
     *   </li>
     *   <li>Resolves an {@code AUTO} distribution from {@code _nclass}:
     *       1 &rarr; gaussian, 2 &rarr; bernoulli, 3+ &rarr; multinomial.</li>
     * </ul>
     *
     * <p>On every pass it then:
     * <ul>
     *   <li>Checks the distribution against the response type.</li>
     *   <li>For bernoulli on the expensive pass, re-seeds {@code _initialPrediction} with
     *       {@code log(m / (1 - m))}.</li>
     *   <li>Requires {@code 0 < learn_rate <= 1}; NaN is rejected.</li>
     * </ul>
     * Problems are reported through {@code error(...)}.
     *
     * @param z {@code true} for the expensive, data-reading pass; {@code false} for the
     *          cheap pass run from the constructor
     */
    @Override
    public void init(boolean z) {
        super.init(z);
        double mean = 0.0d;
        if (z) {
            mean = this._response.mean();
            this._initialPrediction = this._nclass == 1 ? mean
                    : this._nclass == 2 ? (-0.5d) * Math.log(mean / (1.0d - mean)) : 0.0d;
            if (this._parms._distribution == GBMModel.GBMParameters.Family.AUTO) {
                if (this._nclass == 1) {
                    this._parms._distribution = GBMModel.GBMParameters.Family.gaussian;
                }
                if (this._nclass == 2) {
                    this._parms._distribution = GBMModel.GBMParameters.Family.bernoulli;
                }
                if (this._nclass >= 3) {
                    this._parms._distribution = GBMModel.GBMParameters.Family.multinomial;
                }
            }
        }
        switch (this._parms._distribution) {
            case bernoulli:
                if (this._nclass != 2) {
                    error("_distribution", "Binomial requires the response to be a 2-class categorical");
                } else if (z && this._response != null) {
                    // Log-odds of the response mean. This is only valid on the expensive pass:
                    // on the cheap pass `mean` is still 0, and the old code stored
                    // log(0) = -Infinity here.
                    this._initialPrediction = Math.log(mean / (1.0d - mean));
                }
                break;
            case multinomial:
                if (!isClassifier()) {
                    error("_distribution", "Multinomial requires an enum response.");
                }
                break;
            case gaussian:
                if (isClassifier()) {
                    error("_distribution", "Gaussian requires the response to be numeric.");
                }
                break;
            case AUTO:
                break;
            default:
                error("_distribution", "Invalid distribution: " + this._parms._distribution);
                break;
        }
        // Written as a negated in-range test so that NaN is rejected. Every comparison with
        // NaN is false, so the old `0 >= lr || lr > 1` let a NaN learn_rate through.
        double learnRate = this._parms._learn_rate;
        if (!(learnRate > 0.0d && learnRate <= 1.0d)) {
            error("_learn_rate", "learn_rate must be between 0 and 1");
        }
    }

    /** Overrides SharedTree's node factory so decided nodes are {@link GBMDecidedNode}s. */
    @Override
    protected DTree.DecidedNode makeDecided(DTree.UndecidedNode undecidedNode, DHistogram[] dHistogramArr) {
        DTree.DecidedNode decided = new GBMDecidedNode(undecidedNode, dHistogramArr);
        return decided;
    }

    /**
     * Scores row {@code i} from the tree columns into {@code dArr} and returns a normaliser.
     *
     * <ul>
     *   <li>bernoulli: {@code dArr[1] = 1 / (1 + exp(f0))}, {@code dArr[2] = 1 - dArr[1]};
     *       returns 1.</li>
     *   <li>{@code _nclass == 1}: {@code dArr[0] = f0}; returns {@code f0}.</li>
     *   <li>{@code _nclass == 2}: {@code dArr[1] = exp(f0)}, {@code dArr[2] = 1 / dArr[1]};
     *       returns their sum.</li>
     *   <li>otherwise: {@code dArr[k+1] = f_k} for every class; returns
     *       {@code GenModel.log_rescale(dArr)}.</li>
     * </ul>
     * Here {@code f_k} is tree column {@code k} at row {@code i}.
     */
    @Override
    protected double score1(Chunk[] chunkArr, double[] dArr, int i) {
        if (this._parms._distribution == GBMModel.GBMParameters.Family.bernoulli) {
            double p = 1.0d / (1.0d + Math.exp(chk_tree(chunkArr, 0).atd(i)));
            dArr[1] = p;
            dArr[2] = 1.0d - p;
            return 1.0d;
        }
        switch (this._nclass) {
            case 1: {
                double f = chk_tree(chunkArr, 0).atd(i);
                dArr[0] = f;
                return f;
            }
            case 2: {
                double e = Math.exp(chk_tree(chunkArr, 0).atd(i));
                dArr[1] = e;
                dArr[2] = 1.0d / e;
                return dArr[1] + dArr[2];
            }
            default:
                for (int k = 0; k < this._nclass; k++) {
                    dArr[k + 1] = chk_tree(chunkArr, k).atd(i);
                }
                return GenModel.log_rescale(dArr);
        }
    }
}
