package burlap.behavior.functionapproximation.dense;

import burlap.behavior.functionapproximation.DifferentiableStateActionValue;
import burlap.behavior.functionapproximation.DifferentiableStateValue;
import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/functionapproximation/dense/DenseLinearVFA.class */
/**
 * Linear value function approximation over a dense state feature vector phi(s).
 *
 * <p>One instance can act as a state value function, V(s) = w · phi(s) (see
 * {@link #evaluate(State)}), or as a state-action value function, Q(s, a) = w_a · phi(s) (see
 * {@link #evaluate(State, Action)}). In the state-action case every action is assigned an integer
 * offset, and its weight block occupies indices
 * {@code [offset * phi(s).length, (offset + 1) * phi(s).length)} of {@link #stateActionWeights}.
 *
 * <p>The {@code ParametricFunction} accessors ({@link #numParameters()},
 * {@link #getParameter(int)}, {@link #setParameter(int, double)}, {@link #resetParameters()})
 * operate on {@link #stateWeights} once that vector exists, and on {@link #stateActionWeights}
 * otherwise.
 *
 * <p>Evaluation and gradient calls update per-instance caches (last state, its features, the last
 * gradient) without any synchronization. The cache is keyed on state identity ({@code ==}).
 */
public class DenseLinearVFA implements DifferentiableStateValue, DifferentiableStateActionValue {

    /** Produces the dense feature vector phi(s) for a state. */
    protected DenseStateFeatures stateFeatures;

    /** Weights for state evaluation; allocated on the first {@link #evaluate(State)} call. */
    protected double[] stateWeights;

    /** Concatenated per-action weight blocks; grown as new actions are encountered. */
    protected double[] stateActionWeights;

    /** Value given to every newly allocated weight and restored by {@link #resetParameters()}. */
    protected double defaultWeight;

    /** Feature vector of {@link #lastState}. */
    protected double[] currentStateFeatures;

    /** Value produced by the most recent evaluate call. */
    protected double currentValue;

    /** State (compared by identity) whose features are held in {@link #currentStateFeatures}. */
    protected State lastState;

    /** Maps each action to the index of its weight block. */
    protected Map<Action, Integer> actionOffset = new HashMap<>();

    /** Action block index used by the last evaluation: 0 after a state evaluation, -1 before any. */
    protected int currentActionOffset = -1;

    /** Cached gradient for {@link #lastState}, or {@code null} when it must be recomputed. */
    protected FunctionGradient currentGradient = null;

    /**
     * Parameter-index shift that {@link #currentGradient} was built with, or -1 when nothing is
     * cached. Without it, a gradient computed for one action (or for the state function) would be
     * returned for a different action of the same state.
     */
    protected int currentGradientShift = -1;

    /**
     * Creates a linear VFA.
     *
     * @param stateFeatures generator of the dense feature vector phi(s)
     * @param defaultWeight initial value of every weight; also used by {@link #resetParameters()}
     */
    public DenseLinearVFA(DenseStateFeatures stateFeatures, double defaultWeight) {
        this.stateFeatures = stateFeatures;
        this.defaultWeight = defaultWeight;
    }

    /**
     * Evaluates Q(state, action) = w_action · phi(state).
     *
     * <p>Registers {@code action} if it has no offset yet, and grows {@link #stateActionWeights}
     * (filled with the default weight) if its block is not allocated. Invalidates the cached
     * gradient.
     *
     * @param state the state to evaluate
     * @param action the action whose weight block is used
     * @return the approximated state-action value
     */
    @Override
    public double evaluate(State state, Action action) {
        this.currentStateFeatures = this.stateFeatures.features(state);
        this.currentActionOffset = getActionOffset(action);
        int shift = this.currentActionOffset * this.currentStateFeatures.length;
        // Offsets installed through setActionOffset(...) may point past the allocated weights.
        ensureStateActionCapacity(shift + this.currentStateFeatures.length);
        this.currentValue = dot(this.currentStateFeatures, this.stateActionWeights, shift);
        this.currentGradient = null;
        this.currentGradientShift = -1;
        this.lastState = state;
        return this.currentValue;
    }

    /**
     * Evaluates V(state) = w · phi(state).
     *
     * <p>On the first call, allocates {@link #stateWeights} sized to the feature vector and filled
     * with the default weight. Invalidates the cached gradient.
     *
     * @param state the state to evaluate
     * @return the approximated state value
     */
    @Override
    public double evaluate(State state) {
        this.currentStateFeatures = this.stateFeatures.features(state);
        this.currentActionOffset = 0;
        if (this.stateWeights == null) {
            this.stateWeights = new double[this.currentStateFeatures.length];
            Arrays.fill(this.stateWeights, this.defaultWeight);
        }
        this.currentValue = dot(this.currentStateFeatures, this.stateWeights, 0);
        this.currentGradient = null;
        this.currentGradientShift = -1;
        this.lastState = state;
        return this.currentValue;
    }

    /**
     * Returns the gradient of V(state) with respect to the weights. For a linear function this is
     * phi(state) placed at parameter indices {@code [0, phi(state).length)}.
     *
     * <p>The result is cached and reused for repeated calls on the same state (identity) until
     * the next evaluation or a gradient request with a different index shift.
     *
     * @param state the state to differentiate at
     * @return the gradient
     */
    @Override
    public FunctionGradient gradient(State state) {
        return cachedGradient(featuresOf(state), 0);
    }

    /**
     * Returns the gradient of Q(state, action) with respect to the weights. This is phi(state)
     * placed in {@code action}'s weight block, i.e. at indices
     * {@code [offset * n, (offset + 1) * n)} with {@code n = phi(state).length}.
     *
     * <p>Registers {@code action} if it has no offset yet. The result is cached and reused only
     * for the same state (identity) and the same action block.
     *
     * @param state the state to differentiate at
     * @param action the action whose weight block receives the gradient
     * @return the gradient
     */
    @Override
    public FunctionGradient gradient(State state, Action action) {
        // Resolve features first so getActionOffset can size a new weight block from them.
        double[] features = featuresOf(state);
        int shift = getActionOffset(action) * features.length;
        return cachedGradient(features, shift);
    }

    /**
     * Returns the number of parameters: the length of {@link #stateWeights} if allocated,
     * otherwise of {@link #stateActionWeights}, otherwise 0.
     */
    @Override
    public int numParameters() {
        double[] params = activeParameters();
        return params == null ? 0 : params.length;
    }

    /**
     * Returns parameter {@code i} of the active weight vector.
     *
     * @param i parameter index
     * @return the parameter value
     * @throws IndexOutOfBoundsException if no weights are allocated or {@code i} is out of range
     */
    @Override
    public double getParameter(int i) {
        double[] params = activeParameters();
        if (params == null || i < 0 || i >= params.length) {
            throw new IndexOutOfBoundsException("Parameter index out of bounds; parameter cannot be returned.");
        }
        return params[i];
    }

    /**
     * Sets parameter {@code i} of the active weight vector.
     *
     * @param i parameter index
     * @param d new value
     * @throws IndexOutOfBoundsException if no weights are allocated or {@code i} is out of range
     */
    @Override
    public void setParameter(int i, double d) {
        double[] params = activeParameters();
        if (params == null || i < 0 || i >= params.length) {
            throw new IndexOutOfBoundsException("Parameter index out of bounds; parameter cannot be set.");
        }
        params[i] = d;
    }

    /** Resets every weight of the active weight vector (if any) to the default weight. */
    @Override
    public void resetParameters() {
        double[] params = activeParameters();
        if (params != null) {
            Arrays.fill(params, this.defaultWeight);
        }
    }

    /**
     * Returns the weight-block index of {@code action}. An unseen action is assigned the next
     * index, equal to the number of already registered actions.
     *
     * <p>When a new action is registered and the features of the last state are known,
     * {@link #stateActionWeights} is grown by one block. Otherwise the block is allocated lazily
     * by the next {@link #evaluate(State, Action)} call.
     *
     * @param action the action to look up or register
     * @return the action's block index
     */
    public int getActionOffset(Action action) {
        Integer offset = this.actionOffset.get(action);
        if (offset == null) {
            offset = this.actionOffset.size();
            this.actionOffset.put(action, offset);
            if (this.currentStateFeatures != null) {
                expandStateActionWeights(this.currentStateFeatures.length);
            }
        }
        return offset;
    }

    /**
     * Appends {@code i} weights, each set to the default weight, to {@link #stateActionWeights}.
     * If that array is not allocated yet, it is created with length {@code i}.
     *
     * @param i number of weights to add
     */
    protected void expandStateActionWeights(int i) {
        double[] old = this.stateActionWeights;
        int oldLength = old == null ? 0 : old.length;
        double[] expanded = new double[oldLength + i];
        if (old != null) {
            System.arraycopy(old, 0, expanded, 0, oldLength);
        }
        Arrays.fill(expanded, oldLength, expanded.length, this.defaultWeight);
        this.stateActionWeights = expanded;
    }

    /** Returns the feature generator phi(s). */
    public DenseStateFeatures getStateFeatures() {
        return this.stateFeatures;
    }

    /** Returns the value assigned to newly allocated weights. */
    public double getDefaultWeight() {
        return this.defaultWeight;
    }

    /**
     * Replaces {@link #stateWeights} with a new vector.
     *
     * @param i vector length
     * @param d value of every entry
     */
    public void initializeStateWeightVector(int i, double d) {
        this.stateWeights = new double[i];
        Arrays.fill(this.stateWeights, d);
    }

    /**
     * Replaces {@link #stateActionWeights} with a new vector.
     *
     * @param i vector length
     * @param d value of every entry
     */
    public void initializeStateActionWeightVector(int i, double d) {
        this.stateActionWeights = new double[i];
        Arrays.fill(this.stateActionWeights, d);
    }

    /** Returns the live action-to-block-index map (not a copy). */
    public Map<Action, Integer> getActionOffset() {
        return this.actionOffset;
    }

    /** Replaces the action-to-block-index map; the given map is stored, not copied. */
    public void setActionOffset(Map<Action, Integer> map) {
        this.actionOffset = map;
    }

    /** Assigns block index {@code i} to {@code action}. */
    public void setActionOffset(Action action, int i) {
        this.actionOffset.put(action, i);
    }

    /**
     * Returns a copy that shares the feature generator but owns independent copies of the weight
     * vectors and the action-offset map. Evaluation caches are not copied.
     */
    @Override
    public DenseLinearVFA copy() {
        DenseLinearVFA copy = new DenseLinearVFA(this.stateFeatures, this.defaultWeight);
        copy.actionOffset = new HashMap<>(this.actionOffset);
        if (this.stateWeights != null) {
            copy.stateWeights = this.stateWeights.clone();
        }
        if (this.stateActionWeights != null) {
            copy.stateActionWeights = this.stateActionWeights.clone();
        }
        return copy;
    }

    /** Returns the weight vector targeted by the parameter accessors, or null if none exists. */
    private double[] activeParameters() {
        return this.stateWeights != null ? this.stateWeights : this.stateActionWeights;
    }

    /** Grows {@link #stateActionWeights} so it holds at least {@code requiredLength} entries. */
    private void ensureStateActionCapacity(int requiredLength) {
        int currentLength = this.stateActionWeights == null ? 0 : this.stateActionWeights.length;
        if (requiredLength > currentLength) {
            expandStateActionWeights(requiredLength - currentLength);
        }
    }

    /**
     * Returns phi(state), reusing the cached vector when {@code state} is the last state seen
     * (identity). When the state changes, the features are recomputed and the gradient cache is
     * cleared.
     */
    private double[] featuresOf(State state) {
        if (this.lastState != state || this.currentStateFeatures == null) {
            this.currentStateFeatures = this.stateFeatures.features(state);
            this.lastState = state;
            this.currentGradient = null;
            this.currentGradientShift = -1;
        }
        return this.currentStateFeatures;
    }

    /**
     * Returns the gradient {@code features} placed at indices {@code [shift, shift + n)}. A cached
     * gradient is reused only if it was built with the same shift.
     */
    private FunctionGradient cachedGradient(double[] features, int shift) {
        if (this.currentGradient != null && this.currentGradientShift == shift) {
            return this.currentGradient;
        }
        FunctionGradient.SparseGradient gradient = new FunctionGradient.SparseGradient(features.length);
        for (int i = 0; i < features.length; i++) {
            gradient.put(i + shift, features[i]);
        }
        this.currentGradient = gradient;
        this.currentGradientShift = shift;
        return gradient;
    }

    /** Returns sum over i of features[i] * weights[shift + i], accumulated in index order. */
    private static double dot(double[] features, double[] weights, int shift) {
        double sum = 0.0d;
        for (int i = 0; i < features.length; i++) {
            sum += features[i] * weights[shift + i];
        }
        return sum;
    }
}
