package burlap.behavior.singleagent.learning.lspi;

import burlap.debugtools.RandomFactory;
import burlap.mdp.auxiliary.StateGenerator;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionType;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.model.SampleModel;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learning/lspi/SARSCollector.class */
/**
 * Base class for objects that gather state-action-reward-state (SARS) transitions into a
 * {@link SARSData} dataset, either by sampling a {@link SampleModel} from a given start state
 * or by executing actions in an {@link Environment}.
 *
 * <p>Subclasses decide how actions are chosen by implementing the two {@code collectDataFrom}
 * methods; the {@code collectNInstances} methods repeatedly invoke them until a requested number
 * of transitions has been added to the dataset.
 */
public abstract class SARSCollector {
    /** Action types used to enumerate the applicable actions while collecting data. */
    protected List<ActionType> actionTypes;

    /**
     * Creates a collector that uses the action types reported by the given domain.
     *
     * @param sADomain the domain whose {@code getActionTypes()} result is stored in {@link #actionTypes}
     */
    public SARSCollector(SADomain sADomain) {
        this.actionTypes = sADomain.getActionTypes();
    }

    /**
     * Creates a collector that uses the given action types.
     *
     * @param list the action types to store in {@link #actionTypes} (the list is not copied)
     */
    public SARSCollector(List<ActionType> list) {
        this.actionTypes = list;
    }

    /**
     * Collects transitions by sampling a model, starting from the given state.
     *
     * @param state the state from which the rollout starts
     * @param sampleModel the model used to sample transitions
     * @param i the maximum number of transitions to collect in this rollout
     * @param sARSData the dataset to add to, or {@code null} to have a new one created
     * @return the dataset the transitions were added to
     */
    public abstract SARSData collectDataFrom(State state, SampleModel sampleModel, int i, SARSData sARSData);

    /**
     * Collects transitions by acting in an environment, starting from its current observation.
     *
     * @param environment the environment in which actions are executed
     * @param i the maximum number of transitions to collect in this rollout
     * @param sARSData the dataset to add to, or {@code null} to have a new one created
     * @return the dataset the transitions were added to
     */
    public abstract SARSData collectDataFrom(Environment environment, int i, SARSData sARSData);

    /**
     * Collects transitions from a model until {@code i} of them have been added, starting each
     * rollout from a freshly generated state.
     *
     * <p>Progress is measured by the growth of the dataset after each rollout, so an
     * implementation of {@code collectDataFrom} that never adds a transition will keep this
     * method looping.
     *
     * @param stateGenerator source of the start state for each rollout
     * @param sampleModel the model used to sample transitions
     * @param i the number of transitions to collect
     * @param i2 the maximum number of transitions per rollout; must be positive when {@code i > 0}
     * @param sARSData the dataset to add to, or {@code null} to create one via {@code new SARSData(i)}
     * @return the dataset the transitions were added to
     * @throws IllegalArgumentException if {@code i > 0} and {@code i2 <= 0}
     */
    public SARSData collectNInstances(StateGenerator stateGenerator, SampleModel sampleModel, int i, int i2, SARSData sARSData) {
        requirePositiveRolloutLength(i, i2);
        SARSData dataset = sARSData != null ? sARSData : new SARSData(i);
        int remaining = i;
        while (remaining > 0) {
            // Never request more than is still needed, so the total cannot overshoot.
            int rolloutLimit = Math.min(remaining, i2);
            int sizeBefore = dataset.size();
            collectDataFrom(stateGenerator.generateState(), sampleModel, rolloutLimit, dataset);
            remaining -= dataset.size() - sizeBefore;
        }
        return dataset;
    }

    /**
     * Collects transitions from an environment until {@code i} of them have been added. The
     * environment is reset after every rollout, including the last one.
     *
     * <p>Progress is measured by the growth of the dataset after each rollout, so an
     * implementation of {@code collectDataFrom} that never adds a transition will keep this
     * method looping.
     *
     * @param environment the environment in which actions are executed
     * @param i the number of transitions to collect
     * @param i2 the maximum number of transitions per rollout; must be positive when {@code i > 0}
     * @param sARSData the dataset to add to, or {@code null} to create one via {@code new SARSData(i)}
     * @return the dataset the transitions were added to
     * @throws IllegalArgumentException if {@code i > 0} and {@code i2 <= 0}
     */
    public SARSData collectNInstances(Environment environment, int i, int i2, SARSData sARSData) {
        requirePositiveRolloutLength(i, i2);
        SARSData dataset = sARSData != null ? sARSData : new SARSData(i);
        int remaining = i;
        while (remaining > 0) {
            // Never request more than is still needed, so the total cannot overshoot.
            int rolloutLimit = Math.min(remaining, i2);
            int sizeBefore = dataset.size();
            collectDataFrom(environment, rolloutLimit, dataset);
            remaining -= dataset.size() - sizeBefore;
            environment.resetEnvironment();
        }
        return dataset;
    }

    /**
     * Rejects a non-positive per-rollout limit when samples are actually requested; with such a
     * limit no rollout could add a transition and the collection loop would never finish.
     */
    private static void requirePositiveRolloutLength(int requested, int maxRolloutLength) {
        if (requested > 0 && maxRolloutLength <= 0) {
            throw new IllegalArgumentException("Maximum rollout length must be positive to collect "
                    + requested + " samples, but was " + maxRolloutLength);
        }
    }

    /**
     * Collector that, at every step, picks one of the currently applicable actions by drawing an
     * index with {@code RandomFactory.getMapped(0).nextInt(...)} over the applicable-action list.
     */
    public static class UniformRandomSARSCollector extends SARSCollector {
        /**
         * Creates a collector that uses the action types reported by the given domain.
         *
         * @param sADomain the domain supplying the action types
         */
        public UniformRandomSARSCollector(SADomain sADomain) {
            super(sADomain);
        }

        /**
         * Creates a collector that uses the given action types.
         *
         * @param list the action types to choose actions from
         */
        public UniformRandomSARSCollector(List<ActionType> list) {
            super(list);
        }

        /**
         * Rolls out randomly chosen actions through the model, recording each transition. The
         * rollout stops after {@code i} transitions or as soon as a sampled outcome is flagged
         * as terminated, whichever comes first.
         *
         * <p>The start state itself is not checked for terminality, so a call with a positive
         * step limit always records at least one transition.
         *
         * @param state the state from which the rollout starts
         * @param sampleModel the model used to sample transitions
         * @param i the maximum number of transitions to collect
         * @param sARSData the dataset to add to, or {@code null} to have a new one created
         * @return the dataset the transitions were added to
         */
        @Override // burlap.behavior.singleagent.learning.lspi.SARSCollector
        public SARSData collectDataFrom(State state, SampleModel sampleModel, int i, SARSData sARSData) {
            SARSData dataset = sARSData != null ? sARSData : new SARSData();
            State current = state;
            boolean terminated = false;
            for (int step = 0; !terminated && step < i; step++) {
                List<Action> applicable = ActionUtils.allApplicableActionsForTypes(this.actionTypes, current);
                Action chosen = applicable.get(RandomFactory.getMapped(0).nextInt(applicable.size()));
                EnvironmentOutcome outcome = sampleModel.sample(current, chosen);
                dataset.add(current, chosen, outcome.r, outcome.op);
                current = outcome.op;
                // Honor the model's termination signal so no transitions are sampled past it.
                terminated = outcome.terminated;
            }
            return dataset;
        }

        /**
         * Executes randomly chosen actions in the environment, recording each transition. The
         * rollout stops after {@code i} transitions or once the environment reports a terminal
         * state, whichever comes first.
         *
         * @param environment the environment in which actions are executed
         * @param i the maximum number of transitions to collect
         * @param sARSData the dataset to add to, or {@code null} to have a new one created
         * @return the dataset the transitions were added to
         */
        @Override // burlap.behavior.singleagent.learning.lspi.SARSCollector
        public SARSData collectDataFrom(Environment environment, int i, SARSData sARSData) {
            SARSData dataset = sARSData != null ? sARSData : new SARSData();
            for (int step = 0; !environment.isInTerminalState() && step < i; step++) {
                List<Action> applicable = ActionUtils.allApplicableActionsForTypes(this.actionTypes, environment.currentObservation());
                Action chosen = applicable.get(RandomFactory.getMapped(0).nextInt(applicable.size()));
                EnvironmentOutcome outcome = environment.executeAction(chosen);
                dataset.add(outcome.o, outcome.a, outcome.r, outcome.op);
            }
            return dataset;
        }
    }
}
