Source code for lagom.runner

from abc import ABC
from abc import abstractmethod

from lagom.data import StepType
from lagom.data import TimeStep
from lagom.data import Trajectory
from lagom.envs.timestep_env import TimeStepEnv


class BaseRunner(ABC):
    r"""Abstract interface shared by every runner.

    A runner is the data collection layer that sits between an agent and an
    environment. Concrete subclasses must implement :meth:`__call__`.
    """

    @abstractmethod
    def __call__(self, agent, env, **kwargs):
        r"""Collect data by letting the agent interact with the environment.

        Args:
            agent (BaseAgent): agent
            env (Env): environment
            **kwargs: keyword arguments for more specifications.
        """
class EpisodeRunner(BaseRunner):
    r"""Runner that collects a given number of complete episodes."""

    def __call__(self, agent, env, N, **kwargs):
        r"""Roll out ``N`` episodes and return one trajectory per episode.

        Args:
            agent (BaseAgent): agent. Its ``choose_action(timestep, **kwargs)``
                must return a dict that contains the key ``'raw_action'``.
            env (TimeStepEnv): environment; anything else fails the assertion.
            N (int): number of episodes to collect.
            **kwargs: forwarded unchanged to every ``agent.choose_action`` call.

        Returns:
            list: the ``N`` collected :class:`Trajectory` objects, in order.
        """
        assert isinstance(env, TimeStepEnv)
        trajectories = []
        for _ in range(N):
            trajectory = Trajectory()
            current = env.reset()
            # The initial timestep was not produced by any action.
            trajectory.add(current, None)
            while True:
                if current.last():
                    break
                agent_output = agent.choose_action(current, **kwargs)
                # 'raw_action' is removed from the agent output; whatever
                # remains is attached to the resulting timestep below.
                chosen = agent_output.pop('raw_action')
                current = env.step(chosen)
                # Agent entries override environment info on key clashes.
                merged_info = {**current.info, **agent_output}
                current.info = merged_info
                trajectory.add(current, chosen)
            # Query the agent once more on the final timestep of the episode.
            final_output = agent.choose_action(current, last_info=True, **kwargs)
            trajectory.extra_info['last_info'] = final_output
            trajectories.append(trajectory)
        return trajectories
class StepRunner(BaseRunner):
    r"""Runner that collects a fixed number of environment steps per call.

    A single call may span several episodes: whenever a timestep reports
    ``last()``, the current trajectory is closed and the environment is reset.

    Args:
        reset_on_call (bool): if ``True``, the environment is reset at the
            start of every call. Otherwise the observation of the last
            timestep is remembered and the next call starts from it.
    """

    def __init__(self, reset_on_call=True):
        self.reset_on_call = reset_on_call
        # Observation carried over between calls when reset_on_call is False.
        self.observation = None

    def __call__(self, agent, env, T, **kwargs):
        r"""Step the environment ``T`` times and return the trajectories.

        Args:
            agent (BaseAgent): agent. Its ``choose_action(timestep, **kwargs)``
                must return a dict that contains the key ``'raw_action'``.
            env (TimeStepEnv): environment; anything else fails the assertion.
            T (int): number of environment steps to take in this call.
            **kwargs: forwarded unchanged to every ``agent.choose_action`` call.

        Returns:
            list: the collected :class:`Trajectory` objects; the last one may
            be an unfinished episode.
        """
        assert isinstance(env, TimeStepEnv)
        collected = []
        segment = Trajectory()
        resume = not self.reset_on_call and self.observation is not None
        if resume:
            # Continue from the stored observation by wrapping it in a FIRST
            # timestep; the environment itself is not reset here.
            current = TimeStep(StepType.FIRST,
                               observation=self.observation,
                               reward=None,
                               done=None,
                               info=None)
        else:
            current = env.reset()
        segment.add(current, None)
        for _ in range(T):
            agent_output = agent.choose_action(current, **kwargs)
            chosen = agent_output.pop('raw_action')
            current = env.step(chosen)
            # Agent entries override environment info on key clashes.
            merged_info = {**current.info, **agent_output}
            current.info = merged_info
            segment.add(current, chosen)
            if not current.last():
                continue
            # Episode finished: close this trajectory and start a new one.
            final_output = agent.choose_action(current, last_info=True, **kwargs)
            segment.extra_info['last_info'] = final_output
            collected.append(segment)
            segment = Trajectory()
            current = env.reset()
            segment.add(current, None)
        # Keep a trailing unfinished trajectory only if it is non-empty
        # (presumably Trajectory.T counts recorded transitions - confirm).
        if segment.T > 0:
            final_output = agent.choose_action(current, last_info=True, **kwargs)
            segment.extra_info['last_info'] = final_output
            collected.append(segment)
        if not self.reset_on_call:
            # Remember where this call stopped so the next one can resume.
            self.observation = current.observation
        return collected