package burlap.behavior.singleagent.learnfromdemo.mlirl.commonrfs;

import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.functionapproximation.ParametricFunction;
import burlap.behavior.functionapproximation.dense.DenseStateFeatures;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/commonrfs/LinearStateActionDifferentiableRF.class */
/**
 * A differentiable reward function that is linear in state features, with a separate
 * weight vector for each registered action.
 *
 * <p>The parameter vector is laid out in contiguous per-action slices: the weight for
 * state feature {@code j} under the action registered at index {@code k} is stored at
 * {@code parameters[k * numStateFeatures + j]}. The reward for a state-action pair is the
 * dot product of the state feature vector with that action's slice.
 *
 * <p>The source state feature vector is produced by the supplied {@link DenseStateFeatures};
 * the resulting state ({@code s'}) is ignored by both {@link #reward} and {@link #gradient}.
 */
public class LinearStateActionDifferentiableRF implements DifferentiableRF {

    /** Maps each registered action to the index of its slice in {@link #parameters}. */
    protected Map<Action, Integer> actionMap;

    /** Flat weight vector of length {@code numActions * numStateFeatures}. */
    protected double[] parameters;

    /** Total number of parameters ({@code numActions * numStateFeatures}). */
    protected int dim;

    /** Generator of the state feature vector. */
    protected DenseStateFeatures fvGen;

    /** Number of state features each action's weight slice covers. */
    protected int numStateFeatures;

    /** Number of registered actions. */
    int numActions;

    /**
     * Creates a reward function with all weights initialized to zero.
     *
     * @param denseStateFeatures the state feature generator
     * @param i the number of state features per action slice
     * @param actionArr the actions to register; the action at array position {@code k}
     *                  is assigned slice index {@code k}
     */
    public LinearStateActionDifferentiableRF(DenseStateFeatures denseStateFeatures, int i, Action... actionArr) {
        this.fvGen = denseStateFeatures;
        this.numStateFeatures = i;
        this.actionMap = new HashMap<Action, Integer>(actionArr.length);
        for (int k = 0; k < actionArr.length; k++) {
            this.actionMap.put(actionArr[k], Integer.valueOf(k));
        }
        this.numActions = actionArr.length;
        this.dim = this.numActions * this.numStateFeatures;
        this.parameters = new double[this.dim];
    }

    /**
     * Registers an additional action and grows the parameter vector by one slice.
     * Parameter values already set for previously registered actions are preserved;
     * the new action's weights start at zero.
     *
     * @param action the action to register; it is assigned the next free slice index
     */
    public void addAction(Action action) {
        this.actionMap.put(action, Integer.valueOf(this.numActions));
        this.numActions++;
        this.dim = this.numActions * this.numStateFeatures;
        // copyOf keeps the existing weights and zero-pads the new action's slice.
        this.parameters = Arrays.copyOf(this.parameters, this.dim);
    }

    /**
     * Computes the reward as the dot product of the state features of {@code state}
     * with the weight slice of {@code action}.
     *
     * @param state the source state whose features are used
     * @param action the action taken; must have been registered
     * @param state2 the resulting state; not used
     * @return the linear reward value
     * @throws NullPointerException if {@code action} has not been registered
     */
    @Override // burlap.mdp.singleagent.model.RewardFunction
    public double reward(State state, Action action, State state2) {
        double[] features = this.fvGen.features(state);
        int offset = this.actionMap.get(action).intValue() * this.numStateFeatures;
        double sum = 0.0d;
        for (int j = 0; j < this.numStateFeatures; j++) {
            sum += this.parameters[offset + j] * features[j];
        }
        return sum;
    }

    /**
     * Computes the gradient of {@link #reward} with respect to the parameters. Because the
     * reward is linear, the partial derivative for parameter
     * {@code actionIndex * numStateFeatures + j} is state feature {@code j}; parameters
     * belonging to other actions have no entry in the returned sparse gradient.
     *
     * @param state the source state whose features are used
     * @param action the action taken; must have been registered
     * @param state2 the resulting state; not used
     * @return a sparse gradient with entries for the given action's parameter slice
     * @throws NullPointerException if {@code action} has not been registered
     */
    @Override // burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF
    public FunctionGradient gradient(State state, Action action, State state2) {
        double[] features = this.fvGen.features(state);
        // Same slice offset that reward() uses, so gradient indices line up with parameters.
        int offset = this.actionMap.get(action).intValue() * this.numStateFeatures;
        FunctionGradient.SparseGradient sparseGradient = new FunctionGradient.SparseGradient(features.length);
        for (int j = 0; j < this.numStateFeatures; j++) {
            sparseGradient.put(offset + j, features[j]);
        }
        return sparseGradient;
    }

    /**
     * @return the total number of parameters
     */
    @Override // burlap.behavior.functionapproximation.ParametricFunction
    public int numParameters() {
        return this.dim;
    }

    /**
     * @param i the parameter index
     * @return the value of parameter {@code i}
     */
    @Override // burlap.behavior.functionapproximation.ParametricFunction
    public double getParameter(int i) {
        return this.parameters[i];
    }

    /**
     * @param i the parameter index
     * @param d the value to assign to parameter {@code i}
     */
    @Override // burlap.behavior.functionapproximation.ParametricFunction
    public void setParameter(int i, double d) {
        this.parameters[i] = d;
    }

    /** Sets every parameter to zero. */
    @Override // burlap.behavior.functionapproximation.ParametricFunction
    public void resetParameters() {
        Arrays.fill(this.parameters, 0.0d);
    }

    /**
     * Creates a copy with its own action map and parameter array. The feature generator
     * reference is shared with this instance.
     *
     * @return an independent copy with the same actions, dimensions and parameter values
     */
    @Override // burlap.behavior.functionapproximation.ParametricFunction
    public ParametricFunction copy() {
        LinearStateActionDifferentiableRF duplicate =
                new LinearStateActionDifferentiableRF(this.fvGen, this.numStateFeatures);
        duplicate.actionMap.putAll(this.actionMap);
        // The no-action constructor leaves these at 0; they must mirror the copied state.
        duplicate.numActions = this.numActions;
        duplicate.dim = this.dim;
        duplicate.parameters = this.parameters.clone();
        return duplicate;
    }

    /**
     * @return the parameter vector rendered by {@link Arrays#toString(double[])}
     */
    @Override
    public String toString() {
        return Arrays.toString(this.parameters);
    }
}
