Source code for lagom.envs.wrappers.frame_stack

from collections import deque
from lz4.block import compress
from lz4.block import decompress

import numpy as np

from gym.spaces import Box
from gym import ObservationWrapper


class LazyFrames(object):
    r"""Ensures common frames are only stored once to optimize memory use.

    To further reduce the memory use, it is optional to turn on lz4 to
    compress the observations.

    .. note::

        This object should only be converted to numpy array just before forward pass.

    Args:
        frames (list): the observations to hold, oldest first. The list is
            stored as given (not copied) when compression is off.
        lz4_compress (bool): if ``True``, every frame is stored as an
            lz4-compressed byte string. In that case each frame must expose
            ``shape`` and ``dtype`` (e.g. a numpy array), and all frames are
            rebuilt with the shape and dtype of the first frame.
    """

    def __init__(self, frames, lz4_compress=False):
        if lz4_compress:
            # Compressed bytes carry no layout information, so remember the
            # shape AND dtype needed to rebuild a frame. Recording the dtype
            # (instead of assuming uint8) keeps non-image observations,
            # e.g. float32 vectors, intact after a round trip.
            self.shape = frames[0].shape
            self.dtype = frames[0].dtype
            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def __array__(self, dtype=None):
        """Return the frames stacked along a new leading axis.

        Args:
            dtype: optional dtype the stacked array is cast to.

        Returns:
            numpy.ndarray: array whose first dimension indexes the frames.
        """
        if self.lz4_compress:
            frames = [
                np.frombuffer(decompress(frame), dtype=self.dtype).reshape(self.shape)
                for frame in self._frames
            ]
        else:
            frames = self._frames
        out = np.stack(frames, axis=0)
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        """Return the number of stored frames."""
        # Equals the leading dimension of the stacked array, but does not
        # need to decompress and stack every frame just to count them.
        return len(self._frames)

    def __getitem__(self, i):
        """Index into the stacked array (same semantics as numpy indexing)."""
        return self.__array__()[i]
class FrameStack(ObservationWrapper):
    r"""Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation
    contains the most recent 4 observations. For environment 'Pendulum-v0',
    the original observation is an array with shape [3], so if we stack 4
    observations, the processed observation has shape [4, 3] (the stack
    dimension comes first).

    .. note::

        To be memory efficient, the stacked observations are wrapped by :class:`LazyFrames`.

    .. note::

        The observation space must be `Box` type. If one uses `Dict`
        as observation space, it should apply `FlattenDictWrapper` at first.

    Example::

        >>> import gym
        >>> env = gym.make('PongNoFrameskip-v0')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 210, 160, 3)

    Args:
        env (Env): environment object
        num_stack (int): number of stacks
        lz4_compress (bool): forwarded to :class:`LazyFrames`; if ``True`` the
            stacked frames are stored lz4-compressed.
    """

    def __init__(self, env, num_stack, lz4_compress=False):
        super().__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress

        # Rolling buffer: appending to a full deque drops the oldest frame.
        self.frames = deque(maxlen=num_stack)

        # Repeat the per-frame bounds along a new leading axis so the
        # observation space matches the stacked observation layout.
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
        self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)

    def _get_observation(self):
        """Wrap the current buffer content into a :class:`LazyFrames`."""
        # Internal invariant: the buffer is only full once reset() has run.
        assert len(self.frames) == self.num_stack, (
            'frame buffer holds {} of {} frames; was reset() called before step()?'
            .format(len(self.frames), self.num_stack)
        )
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Step the wrapped environment and push the new observation.

        Returns:
            tuple: ``(stacked_observation, reward, done, info)`` where the
            reward, done flag and info are passed through unchanged.
        """
        observation, reward, done, info = self.env.step(action)
        self.frames.append(observation)
        return self._get_observation(), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and fill the buffer with its first observation.

        Args:
            **kwargs: forwarded to the wrapped environment's ``reset``.

        Returns:
            LazyFrames: ``num_stack`` copies of the initial observation.
        """
        observation = self.env.reset(**kwargs)
        # The deque is bounded by num_stack, so this also flushes every
        # frame left over from the previous episode.
        self.frames.extend([observation] * self.num_stack)
        return self._get_observation()