from collections import deque
from lz4.block import compress
from lz4.block import decompress
import numpy as np
from gym.spaces import Box
from gym import ObservationWrapper
class LazyFrames(object):
    r"""Ensures common frames are only stored once to optimize memory use.

    To further reduce memory use, the frames can optionally be stored
    lz4-compressed.

    .. note::

        This object should only be converted to a numpy array just before the
        forward pass.

    Args:
        frames (list): frames to stack along a new leading axis when the
            object is converted to an array.
        lz4_compress (bool): if ``True``, each frame is stored as lz4-compressed
            bytes and decompressed on access.
    """

    def __init__(self, frames, lz4_compress=False):
        if lz4_compress:
            # Compressed frames are raw bytes, so the per-frame shape AND dtype
            # must be remembered to rebuild the arrays on access (decoding with
            # a hard-coded dtype corrupts non-uint8 observations).
            self.shape = frames[0].shape
            self._dtype = frames[0].dtype
            frames = [compress(frame) for frame in frames]
        self._frames = frames
        self.lz4_compress = lz4_compress

    def _decompress(self, frame):
        """Return a single stored frame as an array, decompressing if needed."""
        if not self.lz4_compress:
            return frame
        return np.frombuffer(decompress(frame), dtype=self._dtype).reshape(self.shape)

    def __array__(self, dtype=None):
        out = np.stack([self._decompress(frame) for frame in self._frames], axis=0)
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        # The stacked array has exactly one entry per frame along axis 0, so
        # there is no need to decompress and stack everything just to count.
        return len(self._frames)

    def __getitem__(self, i):
        if isinstance(i, (int, np.integer)) and not isinstance(i, bool):
            # Fast path: materialize only the requested frame. np.array copies,
            # so callers still get a fresh writable array (as with the full
            # stack) and cannot mutate the stored frame.
            return np.array(self._decompress(self._frames[i]))
        # Slices, tuples, masks, etc. keep full numpy indexing semantics.
        return self.__array__()[i]
class FrameStack(ObservationWrapper):
    r"""Observation wrapper that stacks the observations in a rolling manner.

    For example, if the number of stacks is 4, then the returned observation
    contains the 4 most recent observations. For environment 'Pendulum-v0',
    the original observation is an array with shape [3], so if we stack 4
    observations, the processed observation has shape [4, 3] (the stack axis
    is prepended).

    .. note::

        To be memory efficient, the stacked observations are wrapped by
        :class:`LazyFrames`.

    .. note::

        The observation space must be `Box` type. If one uses `Dict` as the
        observation space, `FlattenDictWrapper` should be applied first.

    Example::

        >>> import gym
        >>> env = gym.make('PongNoFrameskip-v0')
        >>> env = FrameStack(env, 4)
        >>> env.observation_space
        Box(4, 210, 160, 3)

    Args:
        env (Env): environment object
        num_stack (int): number of stacks
        lz4_compress (bool): if ``True``, the stacked frames are stored
            lz4-compressed inside :class:`LazyFrames`.

    Raises:
        ValueError: if ``num_stack`` is smaller than 1.
    """

    def __init__(self, env, num_stack, lz4_compress=False):
        if num_stack < 1:
            # A zero-length window would yield observations that cannot be
            # stacked into an array; fail early instead of on first use.
            raise ValueError("num_stack must be at least 1, got {}".format(num_stack))
        super().__init__(env)
        self.num_stack = num_stack
        self.lz4_compress = lz4_compress
        # maxlen makes appends evict the oldest frame automatically.
        self.frames = deque(maxlen=num_stack)

        # New bounds: the original per-observation bounds repeated along a new
        # leading stack axis, i.e. shape (num_stack, *original_shape).
        low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0)
        high = np.repeat(self.observation_space.high[np.newaxis, ...], num_stack, axis=0)
        self.observation_space = Box(low=low, high=high, dtype=self.observation_space.dtype)

    def _get_observation(self):
        """Wrap the current window of frames in a :class:`LazyFrames`."""
        # Internal invariant: reset() fills the window, step() keeps it full.
        assert len(self.frames) == self.num_stack
        return LazyFrames(list(self.frames), self.lz4_compress)

    def step(self, action):
        """Step the wrapped env and push the new observation into the window."""
        observation, reward, done, info = self.env.step(action)
        self.frames.append(observation)
        return self._get_observation(), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the whole window with its first observation."""
        observation = self.env.reset(**kwargs)
        # Repeating the first observation keeps the stacked shape constant from
        # the very first step; because the deque has maxlen=num_stack this
        # also evicts any frames left over from a previous episode.
        self.frames.extend([observation] * self.num_stack)
        return self._get_observation()