Source code for lagom.agent

from abc import ABC
from abc import abstractmethod

from lagom.networks import Module


class BaseAgent(Module, ABC):
    r"""Base class for all agents.

    The agent selects an action from some data (e.g. an observation) and updates
    itself by defining a certain learning mechanism.

    Any agent should subclass this class, e.g. policy-based or value-based agents.

    Args:
        config (dict): a dictionary of configurations
        env (Env): environment object
        device (Device): a PyTorch device
        **kwargs: keyword arguments used to specify the agent

    """
    def __init__(self, config, env, device, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.env = env
        self.device = device

        self.info = {}
        self.is_recurrent = None

    @abstractmethod
    def choose_action(self, x, **kwargs):
        r"""Returns the selected action given the input data.

        The output is a dictionary containing useful items.

        .. note::

            It is recommended to handle all dtype/device conversions between
            CPU/GPU or Tensor/Numpy here.

        Args:
            x (object): batched observation returned from the environment. The
                first dimension is treated as the batch dimension.
            **kwargs: keyword arguments to specify action selection.

        Returns:
            dict: a dictionary of action selection output. It contains all useful
                information (e.g. action, action_logprob, state_value). This allows
                the API to be generic and compatible with different kinds of
                runners and agents.

        """
        pass

    @abstractmethod
    def learn(self, D, **kwargs):
        r"""Defines the learning mechanism to update the agent from batched data.

        Args:
            D (list): a list of batched data to train the agent, e.g. in policy
                gradient, this can be a list of :class:`Trajectory`.
            **kwargs: keyword arguments to specify the learning mechanism.

        Returns:
            dict: a dictionary of learning output. This could contain the loss and
                other useful metrics.

        """
        pass
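

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of lagom): a minimal concrete agent following
# the BaseAgent API above. The linear policy, the Categorical action head and
# the 'raw_action'/'action_logprob' output keys are assumptions chosen for
# demonstration; lagom's real agents define their own networks and outputs.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.distributions import Categorical


class SketchCategoricalAgent(BaseAgent):
    r"""A minimal policy-based agent for a discrete action space (sketch only)."""
    def __init__(self, config, env, device, **kwargs):
        super().__init__(config, env, device, **kwargs)
        # assumption: flat Box observations and a Discrete action space
        self.policy_head = nn.Linear(env.observation_space.shape[0],
                                     env.action_space.n).to(device)

    def choose_action(self, x, **kwargs):
        # handle Numpy->Tensor and CPU->GPU conversion here, as recommended
        # by the docstring of BaseAgent.choose_action
        obs = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        action_dist = Categorical(logits=self.policy_head(obs))
        action = action_dist.sample()
        return {'raw_action': action.detach().cpu().numpy(),
                'action_logprob': action_dist.log_prob(action)}

    def learn(self, D, **kwargs):
        # a real agent would compute a loss from D and take an optimizer step
        return {'loss': None}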


class RandomAgent(BaseAgent):
    r"""An agent that samples actions uniformly at random from the action space."""
    def choose_action(self, x, **kwargs):
        # draw one sample per sub-environment for vectorized environments
        if hasattr(self.env, 'num_envs'):
            action = [self.env.action_space.sample() for _ in range(self.env.num_envs)]
        else:
            action = self.env.action_space.sample()
        out = {'raw_action': action}
        return out

    def learn(self, D, **kwargs):
        pass
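

# Usage sketch (assumption: an OpenAI Gym 'CartPole-v1' environment; not part
# of the original module). RandomAgent needs no configuration or device to act.
if __name__ == '__main__':
    import gym

    env = gym.make('CartPole-v1')
    agent = RandomAgent(config=None, env=env, device=None)
    out = agent.choose_action(env.reset())
    print(out['raw_action'])  # an action sampled uniformly from env.action_space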