package water.rapids.ast.prims.advmath;

import java.util.Random;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/* loaded from: input_file:water/rapids/ast/prims/advmath/AstKFold.class */
/**
 * Rapids primitive {@code kfold_column}: produces a column of fold indices for k-fold
 * cross-validation.
 *
 * <p>Besides the primitive itself ({@link #apply}), the class exposes three static helpers that
 * build a fold column in different ways: pseudo-random ({@link #kfoldColumn}), round-robin by row
 * index ({@link #moduloKfoldColumn}) and per-class pseudo-random ({@link #stratifiedKFoldColumn}).
 */
public class AstKFold extends AstPrimitive {
    /**
     * @return the argument names of this primitive: {@code ary}, {@code nfolds}, {@code seed}
     */
    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"ary", "nfolds", "seed"};
    }

    /**
     * @return 4 — one more than the number of names in {@link #args()}; {@link #apply} reads its
     *     operands from slots 1..3, so slot 0 is presumably the primitive itself (TODO confirm
     *     against {@code AstPrimitive})
     */
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 4;
    }

    /**
     * @return the name of this primitive, {@code "kfold_column"}
     */
    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "kfold_column";
    }

    /**
     * Turns one pseudo-random draw into a fold index.
     *
     * <p>The remainder is taken <em>before</em> the absolute value. {@code Math.abs} returns
     * {@code Integer.MIN_VALUE} unchanged, so {@code Math.abs(draw) % nfolds} can be negative for
     * that single draw; {@code Math.abs(draw % nfolds)} gives the same result for every other draw
     * and is never negative.
     *
     * @param rngSeed seed handed to {@link RandomUtils#getRNG}
     * @param nfolds  number of folds (divisor of the remainder)
     * @return a non-negative value whose magnitude is below {@code |nfolds|}
     */
    private static int randomFold(long rngSeed, int nfolds) {
        return Math.abs(RandomUtils.getRNG(rngSeed).nextInt() % nfolds);
    }

    /**
     * Overwrites every row of {@code vec} with a pseudo-random fold index.
     *
     * <p>Each row gets its own generator, seeded with {@code chunk.start() + j + rowInChunk}, so
     * the value written for a row depends only on that sum and on {@code i}.
     *
     * @param vec the vec to fill; it is modified in place
     * @param i   number of folds
     * @param j   base seed
     * @return the same {@code vec} instance that was passed in
     */
    public static Vec kfoldColumn(Vec vec, final int i, final long j) {
        final int nfolds = i;
        final long seed = j;
        new MRTask() {
            @Override // water.MRTask
            public void map(Chunk chunk) {
                final long start = chunk.start();
                for (int row = 0; row < chunk._len; row++) {
                    chunk.set(row, randomFold(start + seed + row, nfolds));
                }
            }
        }.doAll(vec);
        return vec;
    }

    /**
     * Overwrites every row of {@code vec} with {@code (chunk.start() + rowInChunk) % i}, i.e. a
     * deterministic round-robin assignment.
     *
     * @param vec the vec to fill; it is modified in place
     * @param i   number of folds
     * @return the same {@code vec} instance that was passed in
     */
    public static Vec moduloKfoldColumn(Vec vec, final int i) {
        final int nfolds = i;
        new MRTask() {
            @Override // water.MRTask
            public void map(Chunk chunk) {
                final long start = chunk.start();
                for (int row = 0; row < chunk._len; row++) {
                    chunk.set(row, (int) ((start + row) % nfolds));
                }
            }
        }.doAll(vec);
        return vec;
    }

    /**
     * Builds a fold column in which every class of {@code vec} is assigned to folds with its own
     * pseudo-random sequence.
     *
     * <p>One seed per class is derived from {@code j}. For a non-missing row whose value equals
     * class {@code c}, the fold is drawn from a generator seeded with
     * {@code chunk.start() + rowInChunk + classSeed[c]}. Missing rows are assigned
     * {@code (chunk.start() + rowInChunk) % i}. Rows matched by neither rule keep the value that
     * {@code vec.makeZero()} gave them.
     *
     * @param vec the class/label column; it is read, not written
     * @param i   number of folds
     * @param j   base seed from which the per-class seeds are derived
     * @return the second vec of the task's frame, i.e. the vec created by {@code vec.makeZero()}
     * @throws IllegalArgumentException if {@code vec} is neither categorical nor integer-numeric
     */
    public static Vec stratifiedKFoldColumn(Vec vec, final int i, long j) {
        if (!vec.isCategorical() && (!vec.isNumeric() || !vec.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + vec.get_type_str());
        }
        final int nfolds = i;
        final long[] classes = new VecUtils.CollectDomain().doAll(vec).domain();
        final int nClasses = vec.isNumeric() ? classes.length : vec.domain().length;
        final long[] classSeeds = new long[nClasses];
        for (int c = 0; c < nClasses; c++) {
            classSeeds[c] = RandomUtils.getRNG(j + c).nextLong();
        }
        return new MRTask() {
            @Override // water.MRTask
            public void map(Chunk[] chunkArr) {
                final Chunk labels = chunkArr[0];
                final Chunk folds = chunkArr[1];
                final long start = labels.start();
                // Loop nesting and order are kept exactly as in the original implementation.
                for (int fold = 0; fold < nfolds; fold++) {
                    for (int c = 0; c < nClasses; c++) {
                        for (int row = 0; row < labels._len; row++) {
                            if (labels.isNA(row)) {
                                // Missing labels are spread round-robin by row position.
                                if ((start + row) % nfolds == fold) {
                                    folds.set(row, fold);
                                }
                            } else if (labels.at8(row) == (classes == null ? c : classes[c])
                                    && fold == randomFold(start + row + classSeeds[c], nfolds)) {
                                folds.set(row, fold);
                            }
                        }
                    }
                }
            }
        }.doAll(new Frame(vec, vec.makeZero()))._fr.vec(1);
    }

    /**
     * Executes the primitive: evaluates the frame, {@code nfolds} and {@code seed} operands (in
     * that order) and returns a single-column frame produced by {@link #kfoldColumn}.
     *
     * <p>A seed of {@code -1} is replaced by a value drawn from a fresh {@link Random}.
     *
     * @param env        evaluation environment passed to each operand's {@code exec}
     * @param stackHelp  used to track the evaluated frame operand
     * @param astRootArr operands; slots 1, 2 and 3 are read
     * @return a frame wrapping the generated fold column
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Vec foldVec = stackHelp.track(astRootArr[1].exec(env)).getFrame().anyVec().makeZero();
        int nfolds = (int) astRootArr[2].exec(env).getNum();
        long seed = (long) astRootArr[3].exec(env).getNum();
        if (seed == -1) {
            seed = new Random().nextLong();
        }
        Vec[] columns = new Vec[]{kfoldColumn(foldVec, nfolds, seed)};
        return new ValFrame(new Frame(columns));
    }
}
