mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
55 lines
2 KiB
Python
55 lines
2 KiB
Python
import warnings
|
|
|
|
import numpy as np
|
|
from gym import spaces
|
|
|
|
from torchy_baselines.common.vec_env import VecEnvWrapper
|
|
|
|
|
|
class VecFrameStack(VecEnvWrapper):
    """
    Frame stacking wrapper for vectorized environment.

    The last ``n_stack`` observations of each sub-environment are
    concatenated along the last axis, oldest first and most recent last,
    so the last dimension of the observation space is multiplied by
    ``n_stack``. Slots that have not been filled yet (right after a reset
    or an episode end) are zero.

    :param venv: (VecEnv) the vectorized environment to wrap
    :param n_stack: (int) Number of frames to stack (must be >= 1)
    """

    def __init__(self, venv, n_stack):
        if n_stack < 1:
            raise ValueError("n_stack must be >= 1, got {}".format(n_stack))
        self.venv = venv
        self.n_stack = n_stack
        wrapped_obs_space = venv.observation_space
        # Frames are laid out by concatenation along the last axis (see
        # step_wait/reset), so the bounds must be tiled ([a, b, a, b]) rather
        # than element-wise repeated ([a, a, b, b]); otherwise per-dimension
        # bounds would not line up with the stacked observation.
        low = np.concatenate([wrapped_obs_space.low] * n_stack, axis=-1)
        high = np.concatenate([wrapped_obs_space.high] * n_stack, axis=-1)
        # One stacked observation buffer per sub-environment.
        self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
        observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
        VecEnvWrapper.__init__(self, venv, observation_space=observation_space)

    def step_wait(self):
        """
        Wait for the wrapped step and push the new observations into the stack.

        For every sub-environment that reported ``done``, the
        ``'terminal_observation'`` entry of its info dict (if present) is
        replaced by the stacked version of that terminal observation, and the
        stack for that sub-environment is cleared.

        :return: (tuple) stacked observations, rewards, dones, infos
        """
        observations, rewards, dones, infos = self.venv.step_wait()
        last_ax_size = observations.shape[-1]
        # Shift the oldest frame out; np.roll returns a new array, so the
        # buffer returned by the previous call is left untouched.
        self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
        for i, done in enumerate(dones):
            if done:
                if 'terminal_observation' in infos[i]:
                    old_terminal = infos[i]['terminal_observation']
                    # After the roll, [..., :-last_ax_size] holds the
                    # n_stack - 1 most recent frames of the finished episode.
                    new_terminal = np.concatenate(
                        (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
                    infos[i]['terminal_observation'] = new_terminal
                else:
                    warnings.warn(
                        "VecFrameStack wrapping a VecEnv without terminal_observation info")
                # The wrapped env has already started a new episode; do not
                # leak frames from the previous one into the new stack.
                self.stackedobs[i] = 0
        self.stackedobs[..., -last_ax_size:] = observations
        return self.stackedobs, rewards, dones, infos

    def reset(self):
        """
        Reset all environments

        :return: (np.ndarray) stacked observations with only the newest
            frame filled in and all older slots set to zero
        """
        obs = self.venv.reset()
        # Allocate a fresh buffer instead of zeroing in place: the current
        # buffer may be the very array returned by the last step_wait(), and
        # overwriting it would corrupt observations a caller still holds.
        self.stackedobs = np.zeros_like(self.stackedobs)
        self.stackedobs[..., -obs.shape[-1]:] = obs
        return self.stackedobs

    def close(self):
        """Close the wrapped vectorized environment."""
        self.venv.close()