package burlap.mdp.stochasticgames.world;

import burlap.behavior.stochasticgames.GameEpisode;
import burlap.behavior.stochasticgames.JointPolicy;
import burlap.datastructures.HashedAggregator;
import burlap.debugtools.DPrint;
import burlap.mdp.auxiliary.StateGenerator;
import burlap.mdp.auxiliary.StateMapping;
import burlap.mdp.auxiliary.common.ConstantStateGenerator;
import burlap.mdp.auxiliary.common.IdentityStateMapping;
import burlap.mdp.core.TerminalFunction;
import burlap.mdp.core.state.State;
import burlap.mdp.stochasticgames.JointAction;
import burlap.mdp.stochasticgames.SGDomain;
import burlap.mdp.stochasticgames.agent.SGAgent;
import burlap.mdp.stochasticgames.agent.SGAgentType;
import burlap.mdp.stochasticgames.model.JointModel;
import burlap.mdp.stochasticgames.model.JointRewardFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:burlap/mdp/stochasticgames/world/World.class */
/**
 * Hosts a stochastic game: it holds the current state, the agents that have joined, the joint
 * transition model, the joint reward function and the terminal function, and drives play one
 * stage at a time.
 *
 * <p>A game can be advanced in three ways: by letting the joined agents choose their own actions
 * ({@link #runGame()} / {@link #runStage()}), by following a supplied {@link JointPolicy}
 * ({@link #rolloutJointPolicy(JointPolicy, int)}), or by manually applying a joint action
 * ({@link #executeJointAction(JointAction)}). Registered {@link WorldObserver}s are told about
 * every transition in all three modes.
 *
 * <p>No synchronization is performed anywhere in this class.
 */
public class World {
    /** Domain this world was created for. */
    protected SGDomain domain;
    /** State the world is currently in. */
    protected State currentState;
    /** Agents that have joined, in join order; the list index is used as the player number. */
    protected List<SGAgent> agents;
    /** Running reward total per agent name; nothing in this class ever resets it. */
    protected HashedAggregator<String> agentCumulativeReward;
    /** Joint transition model used to sample next states. */
    protected JointModel worldModel;
    /** Produces one reward per agent for each transition. */
    protected JointRewardFunction jointRewardFunction;
    /** Decides whether a state ends the game. */
    protected TerminalFunction tf;
    /** Source of initial states for new games and rollouts. */
    protected StateGenerator initialStateGenerator;
    /** Mapping applied to world states before they are shown to agents. */
    protected StateMapping abstractionForAgents;
    /** Most recent joint action applied by any of the stepping methods. */
    protected JointAction lastJointAction;
    /** Observers notified of game start/end and of every transition. */
    protected List<WorldObserver> worldObservers;
    /** Episode record of the game or rollout that is running or was run last. */
    protected GameEpisode currentGameEpisodeRecord;
    /** True while a game or rollout is in progress and transitions are being recorded. */
    protected boolean isRecordingGame = false;
    /** Debug code handed to {@link DPrint#cl} for this world's debug output. */
    protected int debugId;
    /** Rewards produced by the most recent transition. */
    protected double[] lastRewards;

    /**
     * Creates a world that always starts from the given state and shows agents the unmodified
     * world state.
     *
     * @param sGDomain the domain; its joint action model becomes the world model
     * @param jointRewardFunction the joint reward function
     * @param terminalFunction the terminal function
     * @param state the fixed initial state
     */
    public World(SGDomain sGDomain, JointRewardFunction jointRewardFunction, TerminalFunction terminalFunction, State state) {
        init(sGDomain, sGDomain.getJointActionModel(), jointRewardFunction, terminalFunction, new ConstantStateGenerator(state), new IdentityStateMapping());
    }

    /**
     * Creates a world whose initial states come from a generator and that shows agents the
     * unmodified world state.
     *
     * @param sGDomain the domain; its joint action model becomes the world model
     * @param jointRewardFunction the joint reward function
     * @param terminalFunction the terminal function
     * @param stateGenerator generator of initial states
     */
    public World(SGDomain sGDomain, JointRewardFunction jointRewardFunction, TerminalFunction terminalFunction, StateGenerator stateGenerator) {
        init(sGDomain, sGDomain.getJointActionModel(), jointRewardFunction, terminalFunction, stateGenerator, new IdentityStateMapping());
    }

    /**
     * Creates a world whose initial states come from a generator and whose states are passed
     * through a mapping before agents see them.
     *
     * @param sGDomain the domain; its joint action model becomes the world model
     * @param jointRewardFunction the joint reward function
     * @param terminalFunction the terminal function
     * @param stateGenerator generator of initial states
     * @param stateMapping mapping applied to states before they are handed to agents
     */
    public World(SGDomain sGDomain, JointRewardFunction jointRewardFunction, TerminalFunction terminalFunction, StateGenerator stateGenerator, StateMapping stateMapping) {
        init(sGDomain, sGDomain.getJointActionModel(), jointRewardFunction, terminalFunction, stateGenerator, stateMapping);
    }

    /**
     * Shared constructor body: stores the collaborators, creates empty agent/observer/reward
     * containers, generates the first current state and assigns the default debug code.
     *
     * @param sGDomain the domain
     * @param jointModel the joint transition model
     * @param jointRewardFunction the joint reward function
     * @param terminalFunction the terminal function
     * @param stateGenerator generator of initial states
     * @param stateMapping mapping applied to states before they are handed to agents
     */
    protected void init(SGDomain sGDomain, JointModel jointModel, JointRewardFunction jointRewardFunction, TerminalFunction terminalFunction, StateGenerator stateGenerator, StateMapping stateMapping) {
        this.domain = sGDomain;
        this.worldModel = jointModel;
        this.jointRewardFunction = jointRewardFunction;
        this.tf = terminalFunction;
        this.initialStateGenerator = stateGenerator;
        this.abstractionForAgents = stateMapping;
        this.agents = new ArrayList<>();
        this.agentCumulativeReward = new HashedAggregator<>();
        this.worldObservers = new ArrayList<>();
        generateNewCurrentState();
        this.debugId = 284673923;
    }

    /**
     * @return the domain of this world
     */
    public SGDomain getDomain() {
        return this.domain;
    }

    /**
     * Replaces the stored domain. The world model captured at construction is not updated.
     *
     * @param sGDomain the new domain
     */
    public void setDomain(SGDomain sGDomain) {
        this.domain = sGDomain;
    }

    /**
     * @return the debug code this world passes to {@link DPrint}
     */
    public int getDebugId() {
        return this.debugId;
    }

    /**
     * @param i the debug code this world should pass to {@link DPrint}
     */
    public void setDebugId(int i) {
        this.debugId = i;
    }

    /**
     * Looks up the reward accumulated so far under the given agent name.
     *
     * @param str the agent name
     * @return the value the reward aggregator holds for that name
     */
    public double getCumulativeRewardForAgent(String str) {
        return this.agentCumulativeReward.v(str);
    }

    /**
     * Adds an agent to this world. Its position in the join order becomes its player number.
     *
     * @param sGAgent the agent to add
     * @throws RuntimeException if an agent with the same name has already joined
     */
    public void join(SGAgent sGAgent) {
        if (agentWithName(sGAgent.agentName()) != null) {
            throw new RuntimeException("Agent with provided name has already joined.");
        }
        this.agents.add(sGAgent);
    }

    /**
     * Finds a joined agent by name.
     *
     * @param str the agent name to look for
     * @return the first joined agent with that name, or {@code null} if there is none
     */
    public SGAgent agentWithName(String str) {
        for (SGAgent candidate : this.agents) {
            if (candidate.agentName().equals(str)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @return the current (unmapped) world state
     */
    public State getCurrentWorldState() {
        return this.currentState;
    }

    /**
     * Replaces the current state with a fresh one from the initial state generator. Does nothing
     * while a game or rollout is running.
     */
    public void generateNewCurrentState() {
        if (gameIsRunning()) {
            return;
        }
        this.currentState = this.initialStateGenerator.generateState();
    }

    /**
     * @return {@code true} if the terminal function reports the current state as terminal
     */
    public boolean worldStateIsTerminal() {
        return this.tf.isTerminal(this.currentState);
    }

    /**
     * Sets the current state. Does nothing while a game or rollout is running.
     *
     * @param state the state to make current
     */
    public void setCurrentState(State state) {
        if (gameIsRunning()) {
            return;
        }
        this.currentState = state;
    }

    /**
     * @return the most recently applied joint action, or {@code null} if none has been applied
     */
    public JointAction getLastJointAction() {
        return this.lastJointAction;
    }

    /**
     * @return the rewards from the most recent transition (the internal array, not a copy), or
     *         {@code null} if no transition has happened yet
     */
    public double[] getLastRewards() {
        return this.lastRewards;
    }

    /**
     * @param worldObserver observer to notify of game events from now on
     */
    public void addWorldObserver(WorldObserver worldObserver) {
        this.worldObservers.add(worldObserver);
    }

    /**
     * @param worldObserver observer to stop notifying
     */
    public void removeWorldObserver(WorldObserver worldObserver) {
        this.worldObservers.remove(worldObserver);
    }

    /**
     * Removes every registered observer.
     */
    public void clearAllWorldObserver() {
        this.worldObservers.clear();
    }

    /**
     * Manually applies a joint action to the current state. Nothing happens if a game or rollout
     * is running or if the current state is terminal.
     *
     * <p>Only world observers are notified; the joined agents are not informed, cumulative
     * rewards are not updated and nothing is added to the episode record.
     *
     * @param jointAction the joint action to apply
     */
    public void executeJointAction(JointAction jointAction) {
        if (gameIsRunning() || this.tf.isTerminal(this.currentState)) {
            return;
        }
        State previous = this.currentState;
        this.currentState = this.worldModel.sample(this.currentState, jointAction);
        double[] reward = this.jointRewardFunction.reward(previous, jointAction, this.currentState);
        this.lastRewards = reward;
        this.lastJointAction = jointAction;
        for (WorldObserver observer : this.worldObservers) {
            observer.observe(previous, jointAction, reward, this.currentState);
        }
    }

    /**
     * Runs a game from a generated initial state until a terminal state is reached.
     *
     * @return the record of the game
     */
    public GameEpisode runGame() {
        return runGame(-1);
    }

    /**
     * Runs a game from a generated initial state.
     *
     * @param i maximum number of stages, or {@code -1} for no limit
     * @return the record of the game
     */
    public GameEpisode runGame(int i) {
        return runGame(i, this.initialStateGenerator.generateState());
    }

    /**
     * Runs a game from the given state. Agents are told the game is starting (with their index in
     * the join order as player number), stages are played until a terminal state or the stage
     * limit is reached, and then agents and observers are told the game has ended.
     *
     * @param i maximum number of stages, or {@code -1} for no limit
     * @param state the state the game starts in
     * @return the record of the game
     */
    public GameEpisode runGame(int i, State state) {
        int playerNumber = 0;
        for (SGAgent agent : this.agents) {
            agent.gameStarting(this, playerNumber);
            playerNumber++;
        }
        this.currentState = state;
        this.currentGameEpisodeRecord = new GameEpisode(this.currentState);
        this.isRecordingGame = true;
        try {
            for (WorldObserver observer : this.worldObservers) {
                observer.gameStarting(this.currentState);
            }
            for (int stage = 0; !this.tf.isTerminal(this.currentState) && (stage < i || i == -1); stage++) {
                runStage();
            }
            for (SGAgent agent : this.agents) {
                agent.gameTerminated();
            }
            for (WorldObserver observer : this.worldObservers) {
                observer.gameEnding(this.currentState);
            }
            DPrint.cl(this.debugId, this.currentState.toString());
        } finally {
            // Reset even when an agent/observer/model throws; otherwise the world would stay
            // "running" and the manual state setters would silently ignore every later call.
            this.isRecordingGame = false;
        }
        return this.currentGameEpisodeRecord;
    }

    /**
     * Follows a joint policy from a generated initial state.
     *
     * @param jointPolicy the policy that selects each joint action
     * @param i maximum number of stages
     * @return the record of the rollout
     */
    public GameEpisode rolloutJointPolicy(JointPolicy jointPolicy, int i) {
        return rolloutJointPolicyFromState(jointPolicy, this.initialStateGenerator.generateState(), i);
    }

    /**
     * Follows a joint policy from the given state until a terminal state or the stage limit is
     * reached. Agents are not consulted or notified; observers see every transition.
     *
     * @param jointPolicy the policy that selects each joint action
     * @param state the state the rollout starts in
     * @param i maximum number of stages
     * @return the record of the rollout
     */
    public GameEpisode rolloutJointPolicyFromState(JointPolicy jointPolicy, State state, int i) {
        this.currentState = state;
        this.currentGameEpisodeRecord = new GameEpisode(this.currentState);
        this.isRecordingGame = true;
        try {
            for (int stage = 0; !this.tf.isTerminal(this.currentState) && stage < i; stage++) {
                rolloutOneStageOfJointPolicy(jointPolicy);
            }
        } finally {
            // Same reasoning as in runGame: never leave the running flag stuck after a failure.
            this.isRecordingGame = false;
        }
        return this.currentGameEpisodeRecord;
    }

    /**
     * Plays one stage with the joined agents: asks each agent for an action, samples the next
     * state, computes rewards, credits them to the agents, informs agents and observers, and
     * advances the current state. Does nothing if the current state is terminal.
     *
     * <p>Agents are shown states passed through the agent state mapping; observers and the
     * episode record receive the unmapped world states.
     */
    public void runStage() {
        if (this.tf.isTerminal(this.currentState)) {
            return;
        }
        JointAction jointAction = new JointAction();
        State agentView = this.abstractionForAgents.mapState(this.currentState);
        for (SGAgent agent : this.agents) {
            jointAction.addAction(agent.action(agentView));
        }
        this.lastJointAction = jointAction;
        DPrint.cl(this.debugId, jointAction.toString());
        State next = this.worldModel.sample(this.currentState, jointAction);
        State nextAgentView = this.abstractionForAgents.mapState(next);
        double[] reward = this.jointRewardFunction.reward(this.currentState, jointAction, next);
        // Arrays.toString prints the values; double[].toString() would only print an identity string.
        DPrint.cl(this.debugId, Arrays.toString(reward));
        // reward[k] is credited to the k-th joined agent.
        // NOTE(review): assumes the reward array follows join order and is no longer than the agent list - confirm.
        for (int k = 0; k < reward.length; k++) {
            this.agentCumulativeReward.add(this.agents.get(k).agentName(), reward[k]);
        }
        boolean nextIsTerminal;
        for (SGAgent agent : this.agents) {
            nextIsTerminal = this.tf.isTerminal(next);
            agent.observeOutcome(agentView, jointAction, reward, nextAgentView, nextIsTerminal);
        }
        for (WorldObserver observer : this.worldObservers) {
            observer.observe(this.currentState, jointAction, reward, next);
        }
        this.currentState = next;
        this.lastRewards = reward;
        if (this.isRecordingGame) {
            this.currentGameEpisodeRecord.transition(this.lastJointAction, this.currentState, reward);
        }
    }

    /**
     * Plays one stage using the joint action chosen by the policy for the current state: samples
     * the next state, computes and credits rewards, informs observers, and advances the current
     * state. Does nothing if the current state is terminal. Joined agents are not notified.
     *
     * @param jointPolicy the policy that selects the joint action
     */
    protected void rolloutOneStageOfJointPolicy(JointPolicy jointPolicy) {
        if (this.tf.isTerminal(this.currentState)) {
            return;
        }
        this.lastJointAction = (JointAction) jointPolicy.action(this.currentState);
        DPrint.cl(this.debugId, this.lastJointAction.toString());
        State next = this.worldModel.sample(this.currentState, this.lastJointAction);
        double[] reward = this.jointRewardFunction.reward(this.currentState, this.lastJointAction, next);
        // Arrays.toString prints the values; double[].toString() would only print an identity string.
        DPrint.cl(this.debugId, Arrays.toString(reward));
        // reward[k] is credited to the k-th joined agent.
        // NOTE(review): this indexes the joined-agent list, so agents must have joined even for a policy rollout - confirm.
        for (int k = 0; k < reward.length; k++) {
            this.agentCumulativeReward.add(this.agents.get(k).agentName(), reward[k]);
        }
        for (WorldObserver observer : this.worldObservers) {
            observer.observe(this.currentState, this.lastJointAction, reward, next);
        }
        this.currentState = next;
        this.lastRewards = reward;
        if (this.isRecordingGame) {
            this.currentGameEpisodeRecord.transition(this.lastJointAction, this.currentState, reward);
        }
    }

    /**
     * @return the joint transition model used by this world
     */
    public JointModel getActionModel() {
        return this.worldModel;
    }

    /**
     * @return the joint reward function used by this world
     */
    public JointRewardFunction getRewardFunction() {
        return this.jointRewardFunction;
    }

    /**
     * @return the terminal function used by this world
     */
    public TerminalFunction getTF() {
        return this.tf;
    }

    /**
     * @return a new list containing the joined agents in join order; changes to the returned
     *         list do not affect this world
     */
    public List<SGAgent> getRegisteredAgents() {
        return new ArrayList<>(this.agents);
    }

    /**
     * @return a new list holding the agent type of every joined agent, in join order
     */
    public List<SGAgentType> getAgentDefinitions() {
        List<SGAgentType> types = new ArrayList<>(this.agents.size());
        for (SGAgent agent : this.agents) {
            types.add(agent.agentType());
        }
        return types;
    }

    /**
     * Finds the player number (index in join order) of the agent with the given name.
     *
     * @param str the agent name
     * @return the index of the first joined agent with that name, or {@code -1} if none matches
     */
    public int getPlayerNumberForAgent(String str) {
        for (int idx = 0; idx < this.agents.size(); idx++) {
            if (this.agents.get(idx).agentName().equals(str)) {
                return idx;
            }
        }
        return -1;
    }

    /**
     * @return {@code true} while a game started by {@code runGame} or a joint-policy rollout is
     *         in progress
     */
    public boolean gameIsRunning() {
        return this.isRecordingGame;
    }
}
