package burlap.domain.singleagent.rlglue;

import burlap.behavior.functionapproximation.dense.DenseStateFeatures;
import burlap.mdp.auxiliary.StateGenerator;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionType;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.rlcommunity.rlglue.codec.EnvironmentInterface;
import org.rlcommunity.rlglue.codec.taskspec.TaskSpecVRLGLUE3;
import org.rlcommunity.rlglue.codec.taskspec.ranges.DoubleRange;
import org.rlcommunity.rlglue.codec.taskspec.ranges.IntRange;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.types.Reward_observation_terminal;
import org.rlcommunity.rlglue.codec.util.EnvironmentLoader;

/* loaded from: input_file:burlap/domain/singleagent/rlglue/RLGlueEnvironment.class */
/**
 * An RL-Glue {@link EnvironmentInterface} implementation backed by a BURLAP {@link SADomain}.
 * <p>
 * Episodes start from states produced by a {@link StateGenerator}, transitions are sampled from
 * the domain's model, and BURLAP states are turned into RL-Glue observations of doubles by a
 * {@link DenseStateFeatures} instance. RL-Glue actions are single integers that index into
 * {@link #actionMap}, which is populated once in the constructor.
 */
public class RLGlueEnvironment implements EnvironmentInterface {
    /** The BURLAP domain whose model is sampled for transitions. */
    protected SADomain domain;

    /** Source of the initial state of each episode. */
    protected StateGenerator stateGenerator;

    /** Converts a BURLAP state into the double array copied into each observation. */
    protected DenseStateFeatures stateFlattener;

    /** One range per observation variable; its length also sizes each observation. */
    protected DoubleRange[] valueRanges;

    /** Reward range advertised in the task specification. */
    protected DoubleRange rewardRange;

    /** Whether the task specification declares the task episodic (true) or continuing (false). */
    protected boolean isEpisodic;

    /** Discount factor advertised in the task specification. */
    protected double discount;

    /** The state the environment is currently in. */
    protected State curState;

    /**
     * Counts steps at or after the first terminating transition of the current episode;
     * reset to 0 by {@link #env_start()}.
     */
    protected int terminalVisits = 0;

    /** Maps the integer action ids exposed to RL-Glue onto BURLAP actions. */
    protected Map<Integer, Action> actionMap = new HashMap<Integer, Action>();

    /**
     * False until the first {@link #env_start()} call, which reuses the state generated in the
     * constructor instead of generating a new one.
     */
    protected boolean usedConstructorState = false;

    /**
     * Creates the environment and builds the integer-to-action mapping.
     * <p>
     * One state is generated immediately; every action that any of the domain's action types
     * reports as applicable in that state receives a consecutive id starting at 0. The mapping is
     * not rebuilt afterwards. The generated state becomes the current state and is the one
     * returned by the first {@link #env_start()} call.
     *
     * @param sADomain the BURLAP domain; its model must not be null
     * @param stateGenerator generates the initial state of each episode
     * @param denseStateFeatures converts states into double arrays for observations
     * @param doubleRangeArr the range of each observation variable
     * @param doubleRange the reward range
     * @param z true if the task is episodic, false if it is continuing
     * @param d the discount factor
     * @throws RuntimeException if the domain has no model
     */
    public RLGlueEnvironment(SADomain sADomain, StateGenerator stateGenerator, DenseStateFeatures denseStateFeatures, DoubleRange[] doubleRangeArr, DoubleRange doubleRange, boolean z, double d) {
        if (sADomain.getModel() == null) {
            throw new RuntimeException("RLGlueEnvironment requires a BURLAP domain with a SampleModel, but the domain does not provide one.");
        }
        this.domain = sADomain;
        this.stateGenerator = stateGenerator;
        this.stateFlattener = denseStateFeatures;
        this.valueRanges = doubleRangeArr;
        this.rewardRange = doubleRange;
        this.isEpisodic = z;
        this.discount = d;

        State initialState = this.stateGenerator.generateState();
        int nextActionId = 0;
        for (ActionType actionType : this.domain.getActionTypes()) {
            for (Action applicable : actionType.allApplicableActions(initialState)) {
                this.actionMap.put(nextActionId, applicable);
                nextActionId++;
            }
        }
        this.curState = initialState;
    }

    /**
     * Runs this environment through an {@link EnvironmentLoader} created with its
     * single-argument constructor.
     */
    public void load() {
        new EnvironmentLoader(this).run();
    }

    /**
     * Runs this environment through an {@link EnvironmentLoader} created with the two given
     * strings, which are passed through unchanged.
     *
     * @param str first loader argument (presumably the RL-Glue host; confirm against the codec)
     * @param str2 second loader argument (presumably the port; confirm against the codec)
     */
    public void load(String str, String str2) {
        new EnvironmentLoader(str, str2, this).run();
    }

    /** Does nothing; this environment holds nothing that needs cleaning up. */
    @Override
    public void env_cleanup() {
    }

    /**
     * Builds the RL-Glue task specification string for this environment.
     * <p>
     * The spec declares the episodic/continuing flag, the discount factor, the reward range, one
     * discrete action dimension covering ids {@code 0} to {@code actionMap.size() - 1}, and one
     * continuous observation per entry of {@link #valueRanges}.
     *
     * @return the task specification string
     */
    @Override
    public String env_init() {
        TaskSpecVRLGLUE3 taskSpec = new TaskSpecVRLGLUE3();
        if (this.isEpisodic) {
            taskSpec.setEpisodic();
        } else {
            taskSpec.setContinuing();
        }
        taskSpec.setDiscountFactor(this.discount);
        taskSpec.setRewardRange(this.rewardRange);
        taskSpec.addDiscreteAction(new IntRange(0, this.actionMap.size() - 1));
        for (DoubleRange observationRange : this.valueRanges) {
            taskSpec.addContinuousObservation(observationRange);
        }
        return taskSpec.toTaskSpec();
    }

    /**
     * Messages are not supported; the argument is ignored.
     *
     * @param str the incoming message (unused)
     * @return a fixed string stating that messages are not supported
     */
    @Override
    public String env_message(String str) {
        return "Messages not supported by default BURLAP RLGlueEnvironment";
    }

    /**
     * Starts a new episode and returns the observation of its initial state.
     * <p>
     * The first call reuses the state generated in the constructor; every later call asks the
     * state generator for a new one.
     *
     * @return the observation of the initial state
     */
    @Override
    public Observation env_start() {
        this.terminalVisits = 0;
        if (this.usedConstructorState) {
            this.curState = this.stateGenerator.generateState();
        } else {
            this.usedConstructorState = true;
        }
        return convertIntoObservation(this.curState);
    }

    /**
     * Applies the action whose id is the first integer of the RL-Glue action.
     * <p>
     * Until a terminating transition has been sampled, the outcome comes from the domain's
     * model. Afterwards each step is a zero-reward self-transition on the current state. The
     * returned terminal flag only becomes true once {@link #terminalVisits} exceeds 2, i.e. on
     * the second step after the one on which the model reported termination.
     *
     * @param action the RL-Glue action; its integer at index 0 selects the BURLAP action
     * @return the reward, the observation of the resulting state and the terminal flag
     */
    @Override
    public Reward_observation_terminal env_step(org.rlcommunity.rlglue.codec.types.Action action) {
        // NOTE(review): an id that is not in actionMap yields null here and is passed on as-is.
        Action burlapAction = this.actionMap.get(action.getInt(0));

        EnvironmentOutcome outcome;
        if (this.terminalVisits == 0) {
            outcome = this.domain.getModel().sample(this.curState, burlapAction);
            if (outcome.terminated) {
                this.terminalVisits++;
            }
        } else {
            // Termination was already reached: stay in place with zero reward.
            outcome = new EnvironmentOutcome(this.curState, burlapAction, this.curState, 0.0d, true);
            this.terminalVisits++;
        }

        Observation observation = convertIntoObservation(outcome.op);
        double reward = outcome.r;
        // The terminal signal is deliberately delayed; see the method documentation.
        boolean terminal = this.terminalVisits > 2;
        this.curState = outcome.op;
        return new Reward_observation_terminal(reward, observation, terminal);
    }

    /**
     * Converts a BURLAP state into an RL-Glue observation with no integers and
     * {@code valueRanges.length} doubles, filled from the state's feature vector.
     *
     * @param state the state to convert
     * @return the observation holding the state's features
     */
    protected Observation convertIntoObservation(State state) {
        Observation observation = new Observation(0, this.valueRanges.length);
        double[] features = this.stateFlattener.features(state);
        // NOTE(review): assumes features.length does not exceed valueRanges.length; not checked here.
        for (int i = 0; i < features.length; i++) {
            observation.setDouble(i, features[i]);
        }
        return observation;
    }
}
