package burlap.behavior.functionapproximation.sparse;

import burlap.behavior.functionapproximation.DifferentiableStateActionValue;
import burlap.behavior.functionapproximation.DifferentiableStateValue;
import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/functionapproximation/sparse/LinearVFA.class */
/**
 * Linear function approximator over sparse features.
 *
 * <p>A state value is computed as {@code sum_i value_i * w[id_i]} over the features returned by
 * {@link SparseStateFeatures#features(State)}. A state-action value uses the same weighted sum
 * over the features returned by a {@link SparseCrossProductFeatures} that wraps the state
 * features.
 *
 * <p>Weights are stored sparsely by feature id. A feature id that has never been seen is
 * assigned {@link #defaultWeight} the first time it is read, and that entry is inserted into the
 * weight map at that point.
 *
 * <p>The most recent query (state, action, features, value, gradient) is cached in instance
 * fields. Reads also mutate the weight map. Instances are therefore not safe for concurrent use
 * without external synchronization.
 */
public class LinearVFA implements DifferentiableStateValue, DifferentiableStateActionValue {

    /** Feature extractor used for state-value queries and wrapped for state-action queries. */
    protected SparseStateFeatures sparseStateFeatures;

    /** State-action feature extractor built on top of {@link #sparseStateFeatures}. */
    protected SparseCrossProductFeatures stateActionFeatures;

    /** Weights keyed by feature id; entries are created lazily by {@link #getWeight(int)}. */
    protected Map<Integer, Double> weights;

    /** Weight given to a feature id the first time it is read. */
    protected double defaultWeight = 0.0d;

    /** Features of the most recent evaluate/gradient query. */
    protected List<StateFeature> currentFeatures;

    /** Value returned by the most recent evaluate call. */
    protected double currentValue;

    /** Gradient cached for ({@link #lastState}, {@link #lastAction}); null until computed. */
    protected FunctionGradient currentGradient = null;

    /** State of the most recent query; compared by reference identity for gradient caching. */
    protected State lastState = null;

    /** Action of the most recent query; null when the last query was state-only. */
    protected Action lastAction = null;

    /**
     * Creates an approximator whose unseen weights default to 0.
     *
     * @param sparseStateFeatures the state feature extractor
     */
    public LinearVFA(SparseStateFeatures sparseStateFeatures) {
        this(sparseStateFeatures, 0.0d);
    }

    /**
     * Creates an approximator with a custom default weight for unseen feature ids.
     *
     * @param sparseStateFeatures the state feature extractor
     * @param defaultWeight weight assigned to a feature id the first time it is read
     */
    public LinearVFA(SparseStateFeatures sparseStateFeatures, double defaultWeight) {
        this.sparseStateFeatures = sparseStateFeatures;
        this.stateActionFeatures = new SparseCrossProductFeatures(sparseStateFeatures);
        this.defaultWeight = defaultWeight;
        this.weights = new HashMap<Integer, Double>();
    }

    /**
     * Evaluates the state-action value and caches the query.
     *
     * @param state the state
     * @param action the action
     * @return the weighted sum of the state-action features
     */
    @Override
    public double evaluate(State state, Action action) {
        List<StateFeature> features = this.stateActionFeatures.features(state, action);
        return this.recordEvaluation(features, state, action);
    }

    /**
     * Evaluates the state value and caches the query (with a null action).
     *
     * @param state the state
     * @return the weighted sum of the state features
     */
    @Override
    public double evaluate(State state) {
        List<StateFeature> features = this.sparseStateFeatures.features(state);
        return this.recordEvaluation(features, state, null);
    }

    /**
     * Returns the gradient of the state value with respect to the weights.
     *
     * <p>For a linear model, this is just the feature values keyed by feature id. The result is
     * reused when {@code state} is the same object ({@code ==}) as the last state-only query.
     * NOTE(review): a state mutated in place between calls would hit a stale cache. Confirm
     * that callers treat states as immutable.
     *
     * @param state the state
     * @return the sparse gradient
     */
    @Override
    public FunctionGradient gradient(State state) {
        List<StateFeature> features;
        if (this.lastState == state && this.lastAction == null) {
            if (this.currentGradient != null) {
                return this.currentGradient;
            }
            // Reuse the features computed by the preceding evaluate(state) call.
            features = this.currentFeatures;
        } else {
            features = this.sparseStateFeatures.features(state);
        }
        return this.recordGradient(features, state, null);
    }

    /**
     * Returns the gradient of the state-action value with respect to the weights.
     *
     * <p>The result is reused when both {@code state} and {@code action} are the same objects
     * ({@code ==}) as the last query.
     *
     * @param state the state
     * @param action the action
     * @return the sparse gradient
     */
    @Override
    public FunctionGradient gradient(State state, Action action) {
        List<StateFeature> features;
        if (this.lastState == state && this.lastAction == action) {
            if (this.currentGradient != null) {
                return this.currentGradient;
            }
            // Reuse the features computed by the preceding evaluate(state, action) call.
            features = this.currentFeatures;
        } else {
            features = this.stateActionFeatures.features(state, action);
        }
        return this.recordGradient(features, state, action);
    }

    /**
     * Returns the number of weights created so far.
     *
     * <p>Weights are created lazily, so this count grows as new feature ids are read.
     */
    @Override
    public int numParameters() {
        return this.weights.size();
    }

    /**
     * Returns the weight for feature id {@code i}.
     *
     * <p>If that weight does not exist yet, it is created with the default value.
     */
    @Override
    public double getParameter(int i) {
        return this.getWeight(i);
    }

    /** Sets the weight for feature id {@code i}. */
    @Override
    public void setParameter(int i, double d) {
        this.weights.put(i, d);
    }

    /**
     * Returns the weight for a feature id.
     *
     * <p>If the id has no weight yet, {@link #defaultWeight} is stored for it and returned.
     *
     * @param i the feature id
     * @return the (possibly newly created) weight
     */
    protected double getWeight(int i) {
        Double w = this.weights.get(i);
        if (w != null) {
            return w;
        }
        this.weights.put(i, this.defaultWeight);
        return this.defaultWeight;
    }

    /** Removes all weights. Subsequent reads start again from {@link #defaultWeight}. */
    @Override
    public void resetParameters() {
        this.weights.clear();
    }

    /**
     * Returns a copy of this approximator.
     *
     * <p>The copy has copied feature extractors, the same default weight and an independent
     * weight map. Cached query state is not copied.
     */
    @Override
    public LinearVFA copy() {
        LinearVFA copy = new LinearVFA(this.sparseStateFeatures.copy(), this.defaultWeight);
        copy.stateActionFeatures = this.stateActionFeatures.copy();
        copy.weights = new HashMap<Integer, Double>(this.weights);
        return copy;
    }

    /** Computes the weighted feature sum and records it as the current query. */
    private double recordEvaluation(List<StateFeature> features, State state, Action action) {
        double value = 0.0d;
        for (StateFeature f : features) {
            // getWeight may insert a default entry for previously unseen ids.
            value += f.value * this.getWeight(f.id);
        }
        this.currentValue = value;
        this.currentGradient = null;
        this.currentFeatures = features;
        this.lastState = state;
        this.lastAction = action;
        return value;
    }

    /** Builds the sparse gradient (feature id -> feature value) and caches it for the query. */
    private FunctionGradient recordGradient(List<StateFeature> features, State state, Action action) {
        FunctionGradient.SparseGradient gradient = new FunctionGradient.SparseGradient(features.size());
        for (StateFeature f : features) {
            gradient.put(f.id, f.value);
        }
        this.currentGradient = gradient;
        this.lastState = state;
        this.lastAction = action;
        this.currentFeatures = features;
        return gradient;
    }
}
