Source code for lagom.agent

from abc import ABC
from abc import abstractmethod

from lagom.networks import Module
from lagom.envs import VecEnv


class BaseAgent(Module, ABC):
    r"""Base class for all agents.

    The agent selects an action from a given observation and updates itself through a
    defined learning mechanism. Any agent should subclass this class, e.g. policy-based
    or value-based agents.

    .. note::

        All agents should by default handle batched data, e.g. a batched observation
        returned from :class:`VecEnv` and a batched action for each sub-environment
        of a :class:`VecEnv`.

    Args:
        config (dict): a dictionary of configurations
        env (VecEnv): environment object
        device (Device): a PyTorch device
        **kwargs: keyword arguments used to specify the agent
    """
    def __init__(self, config, env, device, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.env = env
        self.device = device

        self.info = {}
        self.is_recurrent = None
    @abstractmethod
    def choose_action(self, obs, **kwargs):
        r"""Returns a (batched) action selected by the agent from a received (batched) observation.

        .. note::

            Tensor conversion should be handled here instead of in the policy or
            network forward pass.

        The output is a dictionary containing useful items,
        e.g. action, action_logprob, state_value.

        Args:
            obs (object): batched observation returned from the environment. The first
                dimension is treated as the batch dimension.
            **kwargs: keyword arguments to specify action selection.

        Returns
        -------
        out : dict
            a dictionary of action selection output. It should also contain all useful
            information to be stored during interaction with :class:`BaseRunner`. This
            allows a generic API of the runner classes for all kinds of agents. Note
            that everything should be batched, even a scalar loss, i.e.
            ``scalar_loss -> [scalar_loss]``.
        """
        pass
    @abstractmethod
    def learn(self, D, **kwargs):
        r"""Defines the learning mechanism to update the agent from a batch of data.

        Args:
            D (list): a list of batched data to train the agent, e.g. in policy
                gradient, this can be a list of :class:`Trajectory` or :class:`Segment`
            **kwargs: keyword arguments to specify the learning mechanism

        Returns
        -------
        out : dict
            a dictionary of learning output. This could contain the loss.
        """
        pass
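
# --- Illustrative sketch, not part of lagom ---
# A minimal concrete subclass showing the API contract defined above: choose_action
# receives a batched observation and returns a dict of batched outputs, and learn
# consumes a list of batched data and returns a dict of learning statistics. The
# network sizes, optimizer, loss, and the structure assumed for D are hypothetical
# stand-ins; a real agent would derive them from `config` and the environment spaces.
import torch
import torch.nn as nn
import torch.optim as optim


class SketchPolicyAgent(BaseAgent):
    def __init__(self, config, env, device, **kwargs):
        super().__init__(config, env, device, **kwargs)
        # Hypothetical sizes: 4-dimensional observations, 2 discrete actions.
        self.feature_network = nn.Linear(4, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

    def choose_action(self, obs, **kwargs):
        # Tensor conversion happens here, as the base class docstring requires.
        obs = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        logits = self.feature_network(obs)
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        # Everything in the output dict stays batched.
        return {'raw_action': action.tolist(), 'action_logprob': dist.log_prob(action)}

    def learn(self, D, **kwargs):
        # Hypothetical loss: here D is assumed to be a list of dicts holding the
        # log-probabilities stored by choose_action during rollouts.
        logprobs = torch.cat([d['action_logprob'] for d in D])
        loss = -logprobs.mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}
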
class RandomAgent(BaseAgent):
    r"""A random agent that samples actions uniformly from the action space."""
    def choose_action(self, obs, **kwargs):
        # Sample one action per sub-environment for a VecEnv, otherwise a single action.
        if isinstance(self.env, VecEnv):
            action = [self.env.action_space.sample() for _ in range(len(self.env))]
        else:
            action = self.env.action_space.sample()
        out = {'raw_action': action}
        return out
    def learn(self, D, **kwargs):
        # A random agent performs no learning.
        pass
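
# --- Illustrative usage, not part of lagom ---
# A minimal sketch of RandomAgent on a plain Gym environment, exercising the
# non-VecEnv branch of choose_action. The environment id and the reset/step calls
# follow the classic Gym API; only the RandomAgent constructor comes from this module.
import gym

env = gym.make('CartPole-v0')
agent = RandomAgent(config=None, env=env, device=None)

obs = env.reset()
out = agent.choose_action(obs)
obs, reward, done, info = env.step(out['raw_action'])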