package hex.genmodel.algos.tree;

import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.attributes.parameters.FeatureContribution;
import hex.genmodel.utils.ArrayUtils;

/* loaded from: input_file:hex/genmodel/algos/tree/ContributionsPredictor.class */
/**
 * Base {@link PredictContributions} implementation that delegates the actual contribution
 * computation to a {@link TreeSHAPPredictor}.
 *
 * <p>Subclasses implement {@link #toInputRow(double[])} to convert a raw row into the input type
 * expected by the predictor. They may override {@link #getContribs(float[])} to post-process the
 * computed contributions. The default implementation returns them unchanged.
 *
 * @param <E> row representation consumed by the underlying {@link TreeSHAPPredictor}
 */
public abstract class ContributionsPredictor<E> implements PredictContributions {
    /** Length of the contributions array allocated for every call to {@link #calculateContributions(double[])}. */
    private final int _ncontribs;
    /** Contribution names as passed to the constructor, with {@code "BiasTerm"} appended. */
    private final String[] _contribution_names;
    /** Predictor that fills in the contributions array. */
    private final TreeSHAPPredictor<E> _treeSHAPPredictor;
    /** Workspace size reported by {@code _treeSHAPPredictor}; used to validate the cached workspace. */
    private final int _workspaceSize;

    /**
     * Per-thread cached workspace. The field is {@code static}, so one cached workspace per thread
     * is shared by every {@code ContributionsPredictor} instance. {@link #getWorkspace()} replaces
     * it whenever its size does not match the calling predictor's workspace size.
     */
    private static final ThreadLocal<TreeSHAPPredictor.Workspace> _workspace = new ThreadLocal<>();

    /**
     * @param ncontribs                length of the contributions array to allocate per prediction
     * @param featureContributionNames contribution names; {@code "BiasTerm"} is appended to form
     *                                 {@link #getContributionNames()}
     * @param treeSHAPPredictor        predictor used to compute contributions
     */
    public ContributionsPredictor(int ncontribs, String[] featureContributionNames, TreeSHAPPredictor<E> treeSHAPPredictor) {
        this._ncontribs = ncontribs;
        this._contribution_names = ArrayUtils.append(featureContributionNames, "BiasTerm");
        this._treeSHAPPredictor = treeSHAPPredictor;
        this._workspaceSize = this._treeSHAPPredictor.getWorkspaceSize();
    }

    /**
     * Returns the contribution names, with {@code "BiasTerm"} as the last entry.
     *
     * <p>NOTE(review): this returns the internal array, not a copy, so callers must not modify it.
     */
    @Override // hex.genmodel.PredictContributions
    public final String[] getContributionNames() {
        return this._contribution_names;
    }

    /**
     * Computes the contributions for a single row.
     *
     * @param input raw row, converted via {@link #toInputRow(double[])}
     * @return a fresh array of length {@code ncontribs}, passed through {@link #getContribs(float[])}
     */
    @Override // hex.genmodel.PredictContributions
    public final float[] calculateContributions(double[] input) {
        float[] contribs = new float[this._ncontribs];
        // NOTE(review): the meaning of the literal arguments 0 and -1 is defined by
        // TreeSHAPPredictor.calculateContributions; confirm there before changing them.
        this._treeSHAPPredictor.calculateContributions(toInputRow(input), contribs, 0, -1, getWorkspace());
        return getContribs(contribs);
    }

    /**
     * Converts a raw input row into the representation consumed by the predictor.
     *
     * @param input raw row as passed to {@link #calculateContributions(double[])}
     * @return the converted row
     */
    protected abstract E toInputRow(double[] input);

    /**
     * Post-processing hook for computed contributions. The default implementation returns
     * {@code contribs} unchanged.
     *
     * @param contribs contributions filled in by the predictor
     * @return the contributions to hand back to the caller
     */
    public float[] getContribs(float[] contribs) {
        return contribs;
    }

    /**
     * Returns the calling thread's cached workspace, creating and caching a new one via
     * {@code makeWorkspace()} when there is none or when its size differs from this predictor's.
     */
    private TreeSHAPPredictor.Workspace getWorkspace() {
        TreeSHAPPredictor.Workspace workspace = _workspace.get();
        if (workspace == null || workspace.getSize() != this._workspaceSize) {
            workspace = this._treeSHAPPredictor.makeWorkspace();
            // The freshly made workspace must match the size reported by getWorkspaceSize().
            assert workspace.getSize() == this._workspaceSize;
            _workspace.set(workspace);
        }
        return workspace;
    }

    /**
     * Computes the contributions for a row and returns a selection of them paired with their
     * names. The selection is chosen by {@code ContributionComposer}.
     *
     * @param input      raw row, as for {@link #calculateContributions(double[])}
     * @param topN       forwarded to {@code ContributionComposer.composeContributions}
     * @param bottomN    forwarded to {@code ContributionComposer.composeContributions}
     * @param compareAbs forwarded to {@code ContributionComposer.composeContributions}
     * @return one {@link FeatureContribution} per index returned by the composer, in that order
     */
    @Override // hex.genmodel.PredictContributions
    public FeatureContribution[] calculateContributions(double[] input, int topN, int bottomN, boolean compareAbs) {
        // Keep the computed array in a local: the name/value pairing below indexes into it with the
        // same indices the composer selected.
        float[] contributions = calculateContributions(input);
        int[] contributionIds = ArrayUtils.range(0, this._contribution_names.length - 1);
        int[] selected = new ContributionComposer().composeContributions(contributionIds, contributions, topN, bottomN, compareAbs);
        FeatureContribution[] result = new FeatureContribution[selected.length];
        for (int k = 0; k < selected.length; k++) {
            int idx = selected[k];
            result[k] = new FeatureContribution(this._contribution_names[idx], contributions[idx]);
        }
        return result;
    }
}
