package water.rapids.prims;

import hex.Model;
import java.util.Arrays;
import water.MRTask;
import water.fvec.C0DChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code (predicted.vs.actual.by.var model frame variable predictions)}.
 *
 * <p>For every level of the categorical column {@code variable} of {@code frame}, computes the
 * weighted mean of the model's predictions (first column of {@code predictions}) and the weighted
 * mean of the model's response column. Rows whose {@code variable} value is NA are grouped into an
 * extra trailing level. The result is a three-column frame: the level labels (the NA level has a
 * missing label), the mean prediction and the mean actual value.
 *
 * <p>Only supervised, non-multinomial models are supported. If the model declares a weights column
 * that is present in {@code frame}, it is used as row weights; otherwise every row has weight 1.
 */
public class AstPredictedVsActualByVar extends AstPrimitive<AstPredictedVsActualByVar> {

    /**
     * Accumulates per-level weighted sums of predicted and actual values and turns them into
     * weighted means in {@link #postGlobal()}.
     *
     * <p>Input vecs, in order: predictions, actuals, grouping variable and, optionally, weights.
     * Rows with an NA actual value or a weight equal to zero are skipped. Levels that end up with
     * zero total weight keep 0 as both mean prediction and mean actual.
     */
    public static class PredictedVsActualByVar extends MRTask<PredictedVsActualByVar> {
        // Number of output slots: one per domain level plus a trailing slot for NA levels.
        private final int _s;
        private double[] _preds;
        private double[] _acts;
        private double[] _weights;

        /**
         * @param vec the categorical grouping column; its domain must not be null
         */
        public PredictedVsActualByVar(Vec vec) {
            this._s = vec.domain().length + 1;
        }

        public void map(Chunk[] chunkArr) {
            _preds = new double[_s];
            _acts = new double[_s];
            _weights = new double[_s];
            Chunk predicted = chunkArr[0];
            Chunk actual = chunkArr[1];
            Chunk variable = chunkArr[2];
            // Without a weights column every row gets weight 1.
            Chunk weights = chunkArr.length == 4 ? chunkArr[3] : new C0DChunk(1.0d, predicted._len);
            for (int row = 0; row < actual._len; row++) {
                if (actual.isNA(row)) {
                    continue;
                }
                double w = weights.atd(row);
                if (w == 0.0d) {
                    continue;
                }
                int level = variable.isNA(row) ? _s - 1 : (int) variable.atd(row);
                _preds[level] += w * predicted.atd(row);
                _acts[level] += w * actual.atd(row);
                _weights[level] += w;
            }
        }

        public void reduce(PredictedVsActualByVar other) {
            _preds = ArrayUtils.add(_preds, other._preds);
            _acts = ArrayUtils.add(_acts, other._acts);
            _weights = ArrayUtils.add(_weights, other._weights);
        }

        protected void postGlobal() {
            if (_weights == null) {
                // No chunk was mapped: report all-zero results instead of failing later on null arrays.
                _preds = new double[_s];
                _acts = new double[_s];
                _weights = new double[_s];
                return;
            }
            for (int i = 0; i < _weights.length; i++) {
                double w = _weights[i];
                if (w != 0.0d) {
                    _preds[i] /= w;
                    _acts[i] /= w;
                }
            }
        }
    }

    public String[] args() {
        return new String[]{"model", "frame", "variable", "predictions"};
    }

    public int nargs() {
        return 1 + 4; // (predicted.vs.actual.by.var model frame variable predictions)
    }

    public String str() {
        return "predicted.vs.actual.by.var";
    }

    /**
     * Evaluates the primitive.
     *
     * @param env        evaluation environment
     * @param stackHelp  stack helper used to track evaluated arguments
     * @param astRootArr the call: function name followed by model, frame, variable name and
     *                   predictions frame
     * @return frame with columns {@code variable}, the predictions column name, and {@code actual}
     * @throws IllegalArgumentException if the model is unsupervised or multinomial, if the frame
     *                                  lacks the variable or response column, if the variable is
     *                                  not categorical, if the row counts differ or if actual and
     *                                  predicted domains differ
     */
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Model model = stackHelp.track(astRootArr[1].exec(env)).getModel();
        if (!model.isSupervised()) {
            throw new IllegalArgumentException("Only supervised models are supported for calculating predicted v actual");
        }
        if (model._output.isMultinomialClassifier()) {
            throw new IllegalArgumentException("Multinomial classification models are not supported by predicted v actual");
        }
        Frame frame = stackHelp.track(astRootArr[2].exec(env)).getFrame();
        String varName = stackHelp.track(astRootArr[3].exec(env)).getStr();
        Vec varVec = frame.vec(varName);
        if (varVec == null) {
            throw new IllegalArgumentException("Frame doesn't contain column '" + varName + "'.");
        }
        String[] domain = varVec.domain();
        if (domain == null) {
            // The task needs a domain to size its per-level accumulators.
            throw new IllegalArgumentException("Column '" + varName + "' needs to be categorical.");
        }
        Frame predsFrame = stackHelp.track(astRootArr[4].exec(env)).getFrame();
        if (frame.numRows() != predsFrame.numRows()) {
            throw new IllegalArgumentException("Input frame and frame of predictions need to have same number of rows.");
        }
        Vec predicted = predsFrame.vec(0);
        String responseName = model._output.responseName();
        Vec actual = frame.vec(responseName);
        if (actual == null) {
            throw new IllegalArgumentException("Frame doesn't contain response column '" + responseName + "'.");
        }
        if (!Arrays.equals(actual.domain(), predicted.domain())) {
            throw new IllegalArgumentException("Actual and predicted need to have identical domain.");
        }
        String weightsName = model._output.weightsName();
        // A weights column that is absent from the frame is ignored (unweighted computation).
        Vec weights = weightsName == null ? null : frame.vec(weightsName);
        Vec[] inputs = weights == null
                ? new Vec[]{predicted, actual, varVec}
                : new Vec[]{predicted, actual, varVec, weights};
        PredictedVsActualByVar task = new PredictedVsActualByVar(varVec).doAll(inputs);
        // One label per level plus a trailing null label for the NA slot, matching the task's _s rows.
        String[] labels = Arrays.copyOf(domain, domain.length + 1);
        Frame result = new Frame(
                new String[]{varName, predsFrame.name(0), "actual"},
                new Vec[]{
                        Vec.makeVec(labels, Vec.newKey()),
                        Vec.makeVec(task._preds, Vec.newKey()),
                        Vec.makeVec(task._acts, Vec.newKey())
                });
        return new ValFrame(result);
    }

    /**
     * Alias of {@link #apply(Env, Env.StackHelp, AstRoot[])}, kept so that existing callers of
     * this decompiler-generated name keep working.
     *
     * @deprecated use {@link #apply(Env, Env.StackHelp, AstRoot[])}
     */
    @Deprecated
    public ValFrame m381apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        return apply(env, stackHelp, astRootArr);
    }
}
