stable-baselines3/stable_baselines3/common/vec_env/vec_frame_stack.py
Antonin RAFFIN 23afedb254
Auto-formatting with black and isort (#97)
* Add auto formatting with black and isort

* Reformat code

* Ignore typing errors

* Add note about line length

* Add minimum version for isort

* Add commit-checks

* Update docker image

* Fixed lost import (during last merge)

* Fix opencv dependency
2020-07-16 16:12:16 +02:00

53 lines
2 KiB
Python

import warnings
import numpy as np
from gym import spaces
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environment.

    For every wrapped environment, the ``n_stack`` most recent observations are
    kept concatenated along the last axis (oldest frame first, newest last).

    :param venv: the vectorized environment to wrap
    :param n_stack: Number of frames to stack
    """

    def __init__(self, venv: VecEnv, n_stack: int):
        self.venv = venv
        self.n_stack = n_stack

        wrapped_obs_space = venv.observation_space
        # Frames are stored block-wise along the last axis
        # ([frame_0 | frame_1 | ... | frame_{n-1}]), so the bounds have to be
        # tiled the same way. ``np.repeat`` would duplicate each element in
        # place ([a, a, b, b] instead of [a, b, a, b]) and the declared bounds
        # would not line up with the stacked data whenever they vary along
        # the last axis.
        low = np.tile(wrapped_obs_space.low, self.n_stack)
        high = np.tile(wrapped_obs_space.high, self.n_stack)

        # One stacked buffer per environment, zero-filled until frames arrive.
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        """
        Collect the wrapped env's step results and push the new frames onto the stack.

        For every environment flagged as done, the ``terminal_observation``
        entry of its info dict (when present) is replaced by a stacked version,
        and that environment's frame history is cleared before the new
        observation is written.

        :return: the stacked observations, followed by the rewards, dones and
            infos returned by the wrapped env's ``step_wait``
        """
        observations, rewards, dones, infos = self.venv.step_wait()
        last_ax_size = observations.shape[-1]

        # Shift every stack one frame to the left. The oldest frame wraps
        # around into the last slot, which is overwritten below.
        self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)

        for i, done in enumerate(dones):
            if not done:
                continue
            if "terminal_observation" in infos[i]:
                # Build the stacked terminal observation from the previous
                # frames (everything but the last slot) plus the terminal frame.
                old_terminal = infos[i]["terminal_observation"]
                previous_frames = self.stackedobs[i, ..., :-last_ax_size]
                infos[i]["terminal_observation"] = np.concatenate((previous_frames, old_terminal), axis=-1)
            else:
                warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
            # Drop the finished episode's history for this environment.
            self.stackedobs[i] = 0

        # Write the newest frame into the last slot of every stack.
        self.stackedobs[..., -last_ax_size:] = observations
        # NOTE: this is the internal buffer itself, not a copy.
        return self.stackedobs, rewards, dones, infos

    def reset(self) -> np.ndarray:
        """
        Reset all environments

        The stacked buffer is zeroed in place and the observation returned by
        the wrapped env's ``reset`` is written into the last slot.

        :return: the stacked observations (the internal buffer, not a copy)
        """
        obs = self.venv.reset()
        self.stackedobs[...] = 0
        self.stackedobs[..., -obs.shape[-1] :] = obs
        return self.stackedobs

    def close(self) -> None:
        """
        Close the wrapped vectorized environment.
        """
        self.venv.close()