package biz.k11i.xgboost;

import biz.k11i.xgboost.config.PredictorConfiguration;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.spark.SparkModelParam;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;

/**
 * Loads an XGBoost model from a binary stream and computes predictions with it.
 *
 * <p>Two header layouts are recognized by {@link #readParam(ModelReader)}:
 * - a header starting with the ASCII signature {@code "binf"};
 * - a header carrying an xgboost4j-spark style {@code "_cls_"} / {@code "_reg_"} tag.
 *
 * Any other leading bytes are interpreted directly as the base score and feature count.
 */
public class Predictor implements Serializable {
    private ModelParam mparam;
    private SparkModelParam sparkModelParam;
    private String name_obj;
    private String name_gbm;
    private ObjFunction obj;
    private GradBooster gbm;

    /**
     * Global model parameters read from the model header.
     */
    public static class ModelParam implements Serializable {
        /** Constant added to every raw booster output. */
        final float base_score;
        /** Feature count taken from the model header. */
        final int num_feature;
        /** Class count; forwarded to the booster via {@code setNumClass}. */
        final int num_class;
        /** Non-zero value is passed to {@code GradBooster.loadModel} as {@code true}. */
        final int saved_with_pbuffer;
        /** 30 reserved ints; read so the stream is positioned correctly, otherwise unused here. */
        final int[] reserved;

        /**
         * Completes the parameter block by reading the remaining header fields.
         *
         * @param baseScore  base score already decoded by the caller
         * @param numFeature feature count already decoded by the caller
         * @param reader     reader positioned at the {@code num_class} field
         * @throws IOException if the underlying read fails
         */
        ModelParam(float baseScore, int numFeature, ModelReader reader) throws IOException {
            this.base_score = baseScore;
            this.num_feature = numFeature;
            this.num_class = reader.readInt();
            this.saved_with_pbuffer = reader.readInt();
            this.reserved = reader.readIntArray(30);
        }
    }

    /**
     * Loads a model using {@link PredictorConfiguration#DEFAULT}.
     *
     * @param inputStream stream containing the serialized model; not closed by this constructor
     * @throws IOException if reading the model fails
     */
    public Predictor(InputStream inputStream) throws IOException {
        this(inputStream, null);
    }

    /**
     * Loads a model with an explicit configuration.
     *
     * @param inputStream            stream containing the serialized model; this constructor
     *                               does not itself close it
     * @param predictorConfiguration configuration to use, or {@code null} for
     *                               {@link PredictorConfiguration#DEFAULT}. If it supplies an
     *                               {@code ObjFunction}, that function is used instead of the
     *                               one named in the model.
     * @throws IOException if reading the model fails
     */
    public Predictor(InputStream inputStream, PredictorConfiguration predictorConfiguration) throws IOException {
        PredictorConfiguration config =
                predictorConfiguration != null ? predictorConfiguration : PredictorConfiguration.DEFAULT;
        ModelReader reader = new ModelReader(inputStream);
        readParam(reader);
        initObjFunction(config);
        initObjGbm();
        this.gbm.loadModel(config, reader, this.mparam.saved_with_pbuffer != 0);
    }

    /**
     * Reads the model header.
     *
     * <p>Sets {@code mparam}, {@code name_obj} and {@code name_gbm}. It also sets
     * {@code sparkModelParam}, but only for the Spark-style header.
     *
     * @param reader reader positioned at the start of the model
     * @throws IOException if the underlying read fails
     */
    void readParam(ModelReader reader) throws IOException {
        byte[] first4Bytes = reader.readByteArray(4);
        byte[] next4Bytes = reader.readByteArray(4);

        float baseScore;
        int numFeature;

        if (isBinfSignature(first4Bytes)) {
            // "binf" signature: the base score follows in the next 4 bytes.
            baseScore = reader.asFloat(next4Bytes);
            numFeature = reader.readUnsignedInt();
        } else {
            String sparkModelType = detectSparkModelType(first4Bytes, next4Bytes);
            if (sparkModelType != null) {
                // The last byte already consumed is the high byte of a 16-bit length for the
                // following UTF string. Mask it so a value >= 0x80 is not sign-extended into a
                // negative length. NOTE(review): assumes readByteAsInt returns the low byte as
                // an unsigned 0..255 value -- confirm in ModelReader.
                int utfLength = ((next4Bytes[3] & 0xFF) << 8) + reader.readByteAsInt();
                String featuresCol = reader.readUTF(utfLength);
                this.sparkModelParam = new SparkModelParam(sparkModelType, featuresCol, reader);
                baseScore = reader.readFloat();
                numFeature = reader.readUnsignedInt();
            } else {
                // No recognized signature: the 8 bytes are the base score and feature count.
                baseScore = reader.asFloat(first4Bytes);
                numFeature = reader.asUnsignedInt(next4Bytes);
            }
        }

        this.mparam = new ModelParam(baseScore, numFeature, reader);
        this.name_obj = reader.readString();
        this.name_gbm = reader.readString();
    }

    /** Returns {@code true} when the first four bytes spell the ASCII signature "binf". */
    private static boolean isBinfSignature(byte[] first4Bytes) {
        return first4Bytes[0] == 'b'
                && first4Bytes[1] == 'i'
                && first4Bytes[2] == 'n'
                && first4Bytes[3] == 'f';
    }

    /**
     * Detects an xgboost4j-spark style header.
     *
     * <p>The header is the bytes {@code 00 05} followed by {@code "_cls_"} or {@code "_reg_"}.
     * The {@code 00 05} prefix is presumably a 2-byte length for the 5-character tag.
     *
     * @return {@link SparkModelParam#MODEL_TYPE_CLS}, {@link SparkModelParam#MODEL_TYPE_REG},
     *         or {@code null} if the header does not match
     */
    private static String detectSparkModelType(byte[] first4Bytes, byte[] next4Bytes) {
        if (first4Bytes[0] != 0x00 || first4Bytes[1] != 0x05 || first4Bytes[2] != '_') {
            return null;
        }
        if (first4Bytes[3] == 'c' && next4Bytes[0] == 'l' && next4Bytes[1] == 's' && next4Bytes[2] == '_') {
            return SparkModelParam.MODEL_TYPE_CLS;
        }
        if (first4Bytes[3] == 'r' && next4Bytes[0] == 'e' && next4Bytes[1] == 'g' && next4Bytes[2] == '_') {
            return SparkModelParam.MODEL_TYPE_REG;
        }
        return null;
    }

    /**
     * Chooses the objective function.
     *
     * <p>The configuration's function takes precedence. If the configuration supplies none,
     * the function named in the model is used.
     */
    void initObjFunction(PredictorConfiguration predictorConfiguration) {
        this.obj = predictorConfiguration.getObjFunction();
        if (this.obj == null) {
            this.obj = ObjFunction.fromName(this.name_obj);
        }
    }

    /**
     * Creates the booster named in the model header and sets its class count.
     *
     * <p>The objective is resolved from the model's name only if none has been chosen yet.
     * Overwriting it unconditionally would discard an {@code ObjFunction} supplied through
     * {@link PredictorConfiguration} by {@link #initObjFunction}.
     */
    void initObjGbm() {
        if (this.obj == null) {
            this.obj = ObjFunction.fromName(this.name_obj);
        }
        this.gbm = GradBooster.Factory.createGradBooster(this.name_gbm);
        this.gbm.setNumClass(this.mparam.num_class);
    }

    /**
     * Predicts transformed scores for one instance, passing 0 as the tree limit.
     *
     * @param feat feature vector
     * @return per-output predictions after the objective's transform
     */
    public float[] predict(FVec feat) {
        return predict(feat, false);
    }

    /**
     * Predicts scores for one instance, passing 0 as the tree limit.
     *
     * @param feat         feature vector
     * @param outputMargin if {@code true}, return raw scores (booster output + base score)
     *                     without applying the objective's transform
     * @return per-output predictions
     */
    public float[] predict(FVec feat, boolean outputMargin) {
        return predict(feat, outputMargin, 0);
    }

    /**
     * Predicts scores for one instance.
     *
     * @param feat         feature vector
     * @param outputMargin if {@code true}, return raw scores without the objective's transform
     * @param ntreeLimit   forwarded unchanged to the booster; presumably limits how many trees
     *                     are used (0 meaning all) -- confirm against GradBooster
     * @return per-output predictions
     */
    public float[] predict(FVec feat, boolean outputMargin, int ntreeLimit) {
        float[] preds = predictRaw(feat, ntreeLimit);
        return outputMargin ? preds : this.obj.predTransform(preds);
    }

    /** Returns the booster's outputs with {@code base_score} added in place to each element. */
    float[] predictRaw(FVec feat, int ntreeLimit) {
        float[] preds = this.gbm.predict(feat, ntreeLimit);
        for (int i = 0; i < preds.length; i++) {
            preds[i] += this.mparam.base_score;
        }
        return preds;
    }

    /** Single-output variant of {@link #predict(FVec)}. */
    public float predictSingle(FVec feat) {
        return predictSingle(feat, false);
    }

    /** Single-output variant of {@link #predict(FVec, boolean)}. */
    public float predictSingle(FVec feat, boolean outputMargin) {
        return predictSingle(feat, outputMargin, 0);
    }

    /**
     * Single-output variant of {@link #predict(FVec, boolean, int)}.
     *
     * @param feat         feature vector
     * @param outputMargin if {@code true}, return the raw score without the objective's transform
     * @param ntreeLimit   forwarded unchanged to the booster
     * @return the prediction
     */
    public float predictSingle(FVec feat, boolean outputMargin, int ntreeLimit) {
        float pred = predictSingleRaw(feat, ntreeLimit);
        return outputMargin ? pred : this.obj.predTransform(pred);
    }

    /** Returns the booster's single output plus {@code base_score}. */
    float predictSingleRaw(FVec feat, int ntreeLimit) {
        return this.gbm.predictSingle(feat, ntreeLimit) + this.mparam.base_score;
    }

    /** Delegates to the booster's leaf prediction, passing 0 as the tree limit. */
    public int[] predictLeaf(FVec feat) {
        return predictLeaf(feat, 0);
    }

    /** Delegates to {@code GradBooster.predictLeaf} with the given tree limit. */
    public int[] predictLeaf(FVec feat, int ntreeLimit) {
        return this.gbm.predictLeaf(feat, ntreeLimit);
    }

    /** Delegates to the booster's leaf-path prediction, passing 0 as the tree limit. */
    public String[] predictLeafPath(FVec feat) {
        return predictLeafPath(feat, 0);
    }

    /** Delegates to {@code GradBooster.predictLeafPath} with the given tree limit. */
    public String[] predictLeafPath(FVec feat, int ntreeLimit) {
        return this.gbm.predictLeafPath(feat, ntreeLimit);
    }

    /**
     * Returns the Spark-specific parameters.
     *
     * @return the parameters, or {@code null} if the model did not have the xgboost4j-spark
     *         style header
     */
    public SparkModelParam getSparkModelParam() {
        return this.sparkModelParam;
    }

    /** Returns {@code num_class} as read from the model header. */
    public int getNumClass() {
        return this.mparam.num_class;
    }

    /** Returns the underlying booster. */
    public GradBooster getBooster() {
        return this.gbm;
    }
}
