Source code for lagom.envs.wrappers.resize_observation

import numpy as np
import cv2

from gym.spaces import Box
from gym import ObservationWrapper


[docs]class ResizeObservation(ObservationWrapper): r"""Downsample the image observation to a square image. """ def __init__(self, env, size): super().__init__(env) assert size > 0 self.size = size shape = (self.size, self.size) + self.observation_space.shape[2:] self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8) def observation(self, observation): observation = cv2.resize(observation, (self.size, self.size), interpolation=cv2.INTER_AREA) if observation.ndim == 2: observation = np.expand_dims(observation, -1) return observation