package burlap.behavior.learningrate;

import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/learningrate/ExponentialDecayLR.class */
public class ExponentialDecayLR implements LearningRate {

    /** The learning rate returned before any decay has been applied. */
    protected double initialLearningRate;
    /** Multiplicative factor applied on each decay step; constructors require it to be in [0, 1]. */
    protected double decayRate;
    /** Floor for decayed values: a decay step never produces a value below this. */
    protected double minimumLR = Double.MIN_NORMAL;
    /** The single shared learning rate, used only when state-wise tracking is disabled. */
    protected double universalLR;
    /** Per-state learning rates (state-wise mode only; {@code null} otherwise). */
    protected Map<HashableState, StateWiseLearningRate> stateWiseMap;
    /** Per-feature learning rates keyed by feature id (state-wise mode only; {@code null} otherwise). */
    protected Map<Integer, StateWiseLearningRate> featureWiseMap;
    /** Whether rates are tracked per state / per feature rather than universally. */
    protected boolean useStateWise = false;
    /** Whether state-based rates are further split per action name. */
    protected boolean useStateActionWise = false;
    /** Factory used to hash states into {@link #stateWiseMap} keys (state-wise mode only). */
    protected HashableStateFactory hashingFactory;
    /** The last time step at which {@link #universalLR} was decayed; -1 means never. */
    protected int lastPollTime = -1;

    /**
     * A mutable learning-rate value together with the last time step at which it was decayed.
     */
    public class MutableDouble {
        double md;
        int lastPollTime = -1;

        public MutableDouble(double d) {
            this.md = d;
        }
    }

    /**
     * The learning-rate bookkeeping for a single state (or a single feature id): one rate for the
     * whole state, plus a lazily-populated per-action-name table when state-action-wise mode is on.
     */
    public class StateWiseLearningRate {
        double learningRate;
        /** Per-action-name rates; {@code null} unless state-action-wise mode is enabled. */
        Map<String, MutableDouble> actionLearningRates;
        int lastPollTime = -1;

        public StateWiseLearningRate() {
            this.actionLearningRates = null;
            this.learningRate = ExponentialDecayLR.this.initialLearningRate;
            if (ExponentialDecayLR.this.useStateActionWise) {
                this.actionLearningRates = new HashMap<String, MutableDouble>();
            }
        }

        /**
         * Returns the rate entry for {@code action}, creating it at the initial learning rate on
         * first use. Entries are keyed by {@link Action#actionName()}; the same key must be used
         * for lookup and insertion, otherwise every call would create a fresh, undecayed entry.
         *
         * <p>NOTE(review): actions that share a name (e.g. differently-parameterized actions)
         * share a single rate under this keying — confirm that is intended.
         *
         * @param action the action whose rate entry is requested
         * @return the (possibly new) mutable rate entry for that action name
         */
        public MutableDouble getActionLearningRateEntry(Action action) {
            String key = action.actionName();
            MutableDouble entry = this.actionLearningRates.get(key);
            if (entry == null) {
                entry = new MutableDouble(ExponentialDecayLR.this.initialLearningRate);
                this.actionLearningRates.put(key, entry);
            }
            return entry;
        }
    }

    /**
     * Creates a universal (non-state-wise) exponentially decaying learning rate whose floor is
     * {@link Double#MIN_NORMAL}.
     *
     * @param initialLearningRate the starting learning rate
     * @param decayRate the multiplicative decay factor, in [0, 1]
     * @throws IllegalArgumentException if {@code decayRate} is outside [0, 1] or NaN
     */
    public ExponentialDecayLR(double initialLearningRate, double decayRate) {
        this(initialLearningRate, decayRate, Double.MIN_NORMAL);
    }

    /**
     * Creates a universal (non-state-wise) exponentially decaying learning rate.
     *
     * @param initialLearningRate the starting learning rate
     * @param decayRate the multiplicative decay factor, in [0, 1]
     * @param minimumLearningRate the floor below which decay will not push the rate
     * @throws IllegalArgumentException if {@code decayRate} is outside [0, 1] or NaN
     */
    public ExponentialDecayLR(double initialLearningRate, double decayRate, double minimumLearningRate) {
        checkDecayRate(decayRate);
        this.initialLearningRate = initialLearningRate;
        this.decayRate = decayRate;
        this.minimumLR = minimumLearningRate;
        this.universalLR = initialLearningRate;
    }

    /**
     * Creates a state-wise exponentially decaying learning rate whose floor is
     * {@link Double#MIN_NORMAL}.
     *
     * @param initialLearningRate the starting learning rate
     * @param decayRate the multiplicative decay factor, in [0, 1]
     * @param hashingFactory the factory used to hash states into per-state entries
     * @param useStateActionWise if true, state-based rates are tracked per (state, action name)
     * @throws IllegalArgumentException if {@code decayRate} is outside [0, 1] or NaN
     */
    public ExponentialDecayLR(double initialLearningRate, double decayRate,
            HashableStateFactory hashingFactory, boolean useStateActionWise) {
        this(initialLearningRate, decayRate, Double.MIN_NORMAL, hashingFactory, useStateActionWise);
    }

    /**
     * Creates a state-wise exponentially decaying learning rate.
     *
     * @param initialLearningRate the starting learning rate
     * @param decayRate the multiplicative decay factor, in [0, 1]
     * @param minimumLearningRate the floor below which decay will not push the rate
     * @param hashingFactory the factory used to hash states into per-state entries
     * @param useStateActionWise if true, state-based rates are tracked per (state, action name)
     * @throws IllegalArgumentException if {@code decayRate} is outside [0, 1] or NaN
     */
    public ExponentialDecayLR(double initialLearningRate, double decayRate, double minimumLearningRate,
            HashableStateFactory hashingFactory, boolean useStateActionWise) {
        this(initialLearningRate, decayRate, minimumLearningRate);
        this.useStateWise = true;
        this.useStateActionWise = useStateActionWise;
        this.hashingFactory = hashingFactory;
        this.stateWiseMap = new HashMap<HashableState, StateWiseLearningRate>();
        this.featureWiseMap = new HashMap<Integer, StateWiseLearningRate>();
    }

    /**
     * Validates a decay rate.
     *
     * @param decayRate the candidate decay rate
     * @throws IllegalArgumentException if it is outside [0, 1] or NaN
     */
    private static void checkDecayRate(double decayRate) {
        // Negated range test so that NaN (for which every comparison is false) is rejected too.
        if (!(decayRate >= 0.0 && decayRate <= 1.0)) {
            throw new IllegalArgumentException("Decay rate must be <= 1 and >= 0");
        }
    }

    /**
     * Returns the current learning rate for {@code state}/{@code action} without decaying it.
     * In state-wise mode this may create (but not decay) an entry for a previously unseen state
     * or action.
     */
    @Override
    public double peekAtLearningRate(State state, Action action) {
        if (!this.useStateWise) {
            return this.universalLR;
        }
        StateWiseLearningRate stateEntry = getStateWiseLearningRate(state);
        if (!this.useStateActionWise) {
            return stateEntry.learningRate;
        }
        return stateEntry.getActionLearningRateEntry(action).md;
    }

    /**
     * Returns the current learning rate for {@code state}/{@code action}, then decays the stored
     * value if {@code agentTime} is strictly greater than the last time that value was decayed.
     * The returned value is the rate before this call's decay.
     */
    @Override
    public double pollLearningRate(int agentTime, State state, Action action) {
        if (!this.useStateWise) {
            double current = this.universalLR;
            if (agentTime > this.lastPollTime) {
                this.universalLR = nextLRVal(current);
                this.lastPollTime = agentTime;
            }
            return current;
        }
        StateWiseLearningRate stateEntry = getStateWiseLearningRate(state);
        if (!this.useStateActionWise) {
            double current = stateEntry.learningRate;
            if (agentTime > stateEntry.lastPollTime) {
                stateEntry.learningRate = nextLRVal(current);
                stateEntry.lastPollTime = agentTime;
            }
            return current;
        }
        MutableDouble actionEntry = stateEntry.getActionLearningRateEntry(action);
        double current = actionEntry.md;
        if (agentTime > actionEntry.lastPollTime) {
            actionEntry.md = nextLRVal(current);
            actionEntry.lastPollTime = agentTime;
        }
        return current;
    }

    /**
     * Returns the current learning rate for feature {@code featureId} without decaying it
     * (the universal rate when state-wise mode is off).
     */
    @Override
    public double peekAtLearningRate(int featureId) {
        if (!this.useStateWise) {
            return this.universalLR;
        }
        return getFeatureWiseLearningRate(featureId).learningRate;
    }

    /**
     * Returns the current learning rate for feature {@code featureId}, then decays it if
     * {@code agentTime} is strictly greater than the last time it was decayed. The universal rate
     * is used instead when state-wise mode is off.
     */
    @Override
    public double pollLearningRate(int agentTime, int featureId) {
        if (!this.useStateWise) {
            double current = this.universalLR;
            if (agentTime > this.lastPollTime) {
                this.universalLR = nextLRVal(current);
                this.lastPollTime = agentTime;
            }
            return current;
        }
        StateWiseLearningRate featureEntry = getFeatureWiseLearningRate(featureId);
        double current = featureEntry.learningRate;
        if (agentTime > featureEntry.lastPollTime) {
            featureEntry.learningRate = nextLRVal(current);
            featureEntry.lastPollTime = agentTime;
        }
        return current;
    }

    /**
     * Restores every learning rate to its initial value and forgets all poll times. The per-state
     * and per-feature maps exist only in state-wise mode, so they are cleared only when present.
     */
    @Override
    public void resetDecay() {
        this.universalLR = this.initialLearningRate;
        // Mirror the state-wise path, whose freshly created entries start at lastPollTime == -1.
        this.lastPollTime = -1;
        if (this.stateWiseMap != null) {
            this.stateWiseMap.clear();
        }
        if (this.featureWiseMap != null) {
            this.featureWiseMap.clear();
        }
    }

    /**
     * Returns the entry for {@code state}, creating one at the initial learning rate if absent.
     */
    protected StateWiseLearningRate getStateWiseLearningRate(State state) {
        HashableState key = this.hashingFactory.hashState(state);
        StateWiseLearningRate entry = this.stateWiseMap.get(key);
        if (entry == null) {
            entry = new StateWiseLearningRate();
            this.stateWiseMap.put(key, entry);
        }
        return entry;
    }

    /**
     * Returns the entry for feature {@code featureId}, creating one at the initial learning rate
     * if absent.
     */
    protected StateWiseLearningRate getFeatureWiseLearningRate(int featureId) {
        Integer key = Integer.valueOf(featureId);
        StateWiseLearningRate entry = this.featureWiseMap.get(key);
        if (entry == null) {
            entry = new StateWiseLearningRate();
            this.featureWiseMap.put(key, entry);
        }
        return entry;
    }

    /**
     * Computes one decay step: {@code current * decayRate}, clamped from below at {@link #minimumLR}.
     *
     * @param current the value before decay
     * @return the decayed value
     */
    protected double nextLRVal(double current) {
        return Math.max(current * this.decayRate, this.minimumLR);
    }
}
