package burlap.behavior.stochasticgames.agents.naiveq;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.Policy;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.auxiliary.StateMapping;
import burlap.mdp.auxiliary.common.ShallowIdentityStateMapping;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.core.state.State;
import burlap.mdp.stochasticgames.JointAction;
import burlap.mdp.stochasticgames.SGDomain;
import burlap.mdp.stochasticgames.agent.SGAgentBase;
import burlap.mdp.stochasticgames.agent.SGAgentType;
import burlap.mdp.stochasticgames.world.World;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/stochasticgames/agents/naiveq/SGNaiveQLAgent.class */
/**
 * A stochastic-games agent that performs tabular Q-learning over its own actions only.
 * <p>
 * After each joint action the agent extracts the action stored under its own agent number,
 * reads its own entry of the reward array, and applies a Q-learning update to the Q-value
 * stored for the (state, own action) pair. Q-values are kept in a map keyed by hashed states;
 * each state is passed through a {@link StateMapping} before it is hashed.
 * <p>
 * By default the agent follows an {@link EpsilonGreedy} policy (epsilon 0.1) over its own
 * Q-values and uses a {@link ShallowIdentityStateMapping} as the stored state abstraction.
 */
public class SGNaiveQLAgent extends SGAgentBase implements QProvider {
    /** Default epsilon of the epsilon-greedy policy installed by the constructors. */
    private static final double DEFAULT_EPSILON = 0.1d;

    /** Stored Q-value entries for each hashed (abstracted) state. */
    protected Map<HashableState, List<QValue>> qMap;
    /** The state object stored for each hashed state; also serves as the "state already seen" marker. */
    protected Map<HashableState, State> stateRepresentations;
    /** Mapping applied to every state before it is hashed and stored. */
    protected StateMapping storedMapAbstraction;
    /** Discount factor applied to the maximum next-state Q-value in the update. */
    protected double discount;
    /** Learning rate schedule polled on every update. */
    protected LearningRate learningRate;
    /** Source of the initial Q-value for newly created entries. */
    protected QFunction qInit;
    /** Policy used by {@link #action(State)} to select actions. */
    protected Policy policy;
    /** Factory used to hash states for the lookup tables. */
    protected HashableStateFactory hashFactory;
    /** This agent's number, set in {@link #gameStarting(World, int)}; indexes joint actions and reward arrays. */
    protected int agentNum;
    /** Number of updates performed so far; passed to the learning rate schedule. */
    protected int totalNumberOfSteps = 0;

    /**
     * Creates an agent whose Q-values are initialized to 0.
     *
     * @param domain the stochastic games domain, passed to {@code init}
     * @param discount the discount factor
     * @param learningRate the constant learning rate
     * @param hashFactory the factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain domain, double discount, double learningRate, HashableStateFactory hashFactory) {
        this(domain, discount, learningRate, new ConstantValueFunction(0.0d), hashFactory);
    }

    /**
     * Creates an agent whose Q-values are initialized to a constant.
     *
     * @param domain the stochastic games domain, passed to {@code init}
     * @param discount the discount factor
     * @param learningRate the constant learning rate
     * @param defaultQ the initial Q-value for every new entry
     * @param hashFactory the factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain domain, double discount, double learningRate, double defaultQ, HashableStateFactory hashFactory) {
        this(domain, discount, learningRate, new ConstantValueFunction(defaultQ), hashFactory);
    }

    /**
     * Creates an agent whose Q-values are initialized by the given function.
     *
     * @param domain the stochastic games domain, passed to {@code init}
     * @param discount the discount factor
     * @param learningRate the constant learning rate
     * @param qInitializer the function supplying initial Q-values for new entries
     * @param hashFactory the factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain domain, double discount, double learningRate, QFunction qInitializer, HashableStateFactory hashFactory) {
        init(domain);
        this.discount = discount;
        this.learningRate = new ConstantLR(Double.valueOf(learningRate));
        this.hashFactory = hashFactory;
        this.qInit = qInitializer;
        this.qMap = new HashMap<HashableState, List<QValue>>();
        this.stateRepresentations = new HashMap<HashableState, State>();
        this.policy = new EpsilonGreedy(this, DEFAULT_EPSILON);
        this.storedMapAbstraction = new ShallowIdentityStateMapping();
    }

    /**
     * Sets this agent's name in the world and its agent type.
     *
     * @param agentName the name to store as the world agent name
     * @param type the agent type; its action types are used to enumerate applicable actions
     * @return this agent, for chaining
     */
    @Override // burlap.mdp.stochasticgames.agent.SGAgentBase
    public SGNaiveQLAgent setAgentDetails(String agentName, SGAgentType type) {
        this.worldAgentName = agentName;
        this.agentType = type;
        return this;
    }

    /**
     * Sets the state mapping applied to states before they are hashed and stored.
     *
     * @param abstraction the state mapping to use
     */
    public void setStoredMapAbstraction(StateMapping abstraction) {
        this.storedMapAbstraction = abstraction;
    }

    /**
     * Sets the policy this agent uses to select actions.
     *
     * @param strategy the policy to follow
     */
    public void setStrategy(Policy strategy) {
        this.policy = strategy;
    }

    /**
     * Sets the function used to initialize Q-values of entries created from now on.
     *
     * @param qInitializer the Q-value initializer
     */
    public void setQValueInitializer(QFunction qInitializer) {
        this.qInit = qInitializer;
    }

    /**
     * Sets the learning rate schedule.
     *
     * @param learningRate the learning rate schedule to poll on each update
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Records the world being played in and this agent's number for the game.
     *
     * @param world the world the game is played in
     * @param agentNumber this agent's number in the game
     */
    @Override // burlap.mdp.stochasticgames.agent.SGAgent
    public void gameStarting(World world, int agentNumber) {
        this.world = world;
        this.agentNum = agentNumber;
    }

    /**
     * Selects an action for the given state by delegating to the current policy.
     *
     * @param state the current state
     * @return the action chosen by the policy
     */
    @Override // burlap.mdp.stochasticgames.agent.SGAgent
    public Action action(State state) {
        return this.policy.action(state);
    }

    /**
     * Applies one Q-learning update for this agent's own action in the observed transition.
     * <p>
     * If an internal reward function is set, its rewards replace the supplied ones.
     *
     * @param state the state in which the joint action was taken
     * @param jointAction the joint action taken; only this agent's component is used
     * @param jointReward the rewards received, indexed by agent number
     * @param nextState the resulting state
     * @param isTerminal whether {@code nextState} is terminal; if so, the next-state value is 0
     */
    @Override // burlap.mdp.stochasticgames.agent.SGAgent
    public void observeOutcome(State state, JointAction jointAction, double[] jointReward, State nextState, boolean isTerminal) {
        double[] rewards = jointReward;
        if (this.internalRewardFunction != null) {
            rewards = this.internalRewardFunction.reward(state, jointAction, nextState);
        }

        Action myAction = jointAction.action(this.agentNum);
        double myReward = rewards[this.agentNum];
        QValue entry = storedQ(state, myAction);

        // Terminal states contribute no future value.
        double maxNextQ = isTerminal ? 0.0d : getMaxQValue(nextState);

        double alpha = this.learningRate.pollLearningRate(this.totalNumberOfSteps, state, myAction);
        double target = myReward + (this.discount * maxNextQ);
        entry.q += alpha * (target - entry.q);
        this.totalNumberOfSteps++;
    }

    /** Does nothing; no end-of-game processing is performed. */
    @Override // burlap.mdp.stochasticgames.agent.SGAgent
    public void gameTerminated() {
    }

    /**
     * Returns the largest Q-value among the entries returned by {@link #qValues(State)}.
     *
     * @param state the state to evaluate
     * @return the maximum Q-value, or {@link Double#NEGATIVE_INFINITY} if the list is empty
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public double getMaxQValue(State state) {
        double best = Double.NEGATIVE_INFINITY;
        for (QValue entry : qValues(state)) {
            best = Math.max(best, entry.q);
        }
        return best;
    }

    /**
     * Applies the stored state abstraction to the state and hashes the result.
     *
     * @param state the state to hash
     * @return the hashed, abstracted state used as the lookup key
     */
    protected HashableState stateHash(State state) {
        return this.hashFactory.hashState(this.storedMapAbstraction.mapState(state));
    }

    /**
     * Returns the stored Q-values for every action of this agent's type applicable in the state,
     * creating (and storing) entries with initial values for any that do not exist yet.
     * <p>
     * NOTE(review): for a state seen for the first time an empty list is stored and returned when
     * no action is applicable, whereas for an already-known state an empty result throws.
     *
     * @param state the state whose Q-values are requested
     * @return the Q-value entries, one per applicable action, in applicable-action order
     * @throws RuntimeException if the state is already known and no action is applicable
     */
    @Override // burlap.behavior.valuefunction.QProvider
    public List<QValue> qValues(State state) {
        List<Action> applicable = ActionUtils.allApplicableActionsForTypes(this.agentType.actions, state);
        HashableState hashed = stateHash(state);

        if (this.stateRepresentations.get(hashed) == null) {
            // First visit: register the state and create one entry per applicable action.
            this.stateRepresentations.put(hashed, hashed.s());
            List<QValue> created = new ArrayList<QValue>();
            for (Action action : applicable) {
                created.add(new QValue(hashed.s(), action, this.qInit.qValue(hashed.s(), action)));
            }
            this.qMap.put(hashed, created);
            return created;
        }

        // Known state: reuse stored entries, lazily adding any applicable action not yet stored.
        List<QValue> stored = this.qMap.get(hashed);
        List<QValue> result = new ArrayList<QValue>(applicable.size());
        for (Action action : applicable) {
            QValue entry = findQ(stored, action);
            if (entry == null) {
                entry = new QValue(hashed.s(), action, this.qInit.qValue(hashed.s(), action));
                stored.add(entry);
            }
            result.add(entry);
        }
        if (result.isEmpty()) {
            throw new RuntimeException("No applicable actions found for a previously seen state.");
        }
        return result;
    }

    /**
     * Returns the state value as computed by {@code QProvider.Helper.maxQ} for this provider.
     *
     * @param state the state to evaluate
     * @return the value of the state
     */
    @Override // burlap.behavior.valuefunction.ValueFunction
    public double value(State state) {
        return QProvider.Helper.maxQ(this, state);
    }

    /**
     * Returns the stored Q-value for the state-action pair, creating the entry if needed.
     *
     * @param state the state
     * @param action the action
     * @return the current Q-value
     */
    @Override // burlap.behavior.valuefunction.QFunction
    public double qValue(State state, Action action) {
        return storedQ(state, action).q;
    }

    /**
     * Returns the stored (mutable) Q-value entry for the state-action pair, creating and storing
     * it with its initial value when it does not exist yet.
     *
     * @param state the state
     * @param action the action
     * @return the stored entry; modifying its {@code q} field updates the table
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public QValue storedQ(State state, Action action) {
        HashableState hashed = stateHash(state);

        if (this.stateRepresentations.get(hashed) == null) {
            // First visit: register the state and start its entry list with this action.
            // The entry must reference the stored state, not the (null) failed lookup result.
            State storedState = hashed.s();
            this.stateRepresentations.put(hashed, storedState);
            QValue created = new QValue(storedState, action, this.qInit.qValue(storedState, action));
            List<QValue> entries = new ArrayList<QValue>();
            entries.add(created);
            this.qMap.put(hashed, entries);
            return created;
        }

        List<QValue> entries = this.qMap.get(hashed);
        QValue existing = findQ(entries, action);
        if (existing != null) {
            return existing;
        }
        QValue created = new QValue(hashed.s(), action, this.qInit.qValue(hashed.s(), action));
        entries.add(created);
        return created;
    }

    /** Returns the first entry whose action equals the given action, or null if there is none. */
    private static QValue findQ(List<QValue> entries, Action action) {
        for (QValue entry : entries) {
            if (entry.a.equals(action)) {
                return entry;
            }
        }
        return null;
    }
}
