diff --git a/torchy_baselines/common/vec_env/vec_frame_stack.py b/torchy_baselines/common/vec_env/vec_frame_stack.py index 562c525..6676162 100644 --- a/torchy_baselines/common/vec_env/vec_frame_stack.py +++ b/torchy_baselines/common/vec_env/vec_frame_stack.py @@ -3,18 +3,18 @@ import warnings import numpy as np from gym import spaces -from torchy_baselines.common.vec_env.base_vec_env import VecEnvWrapper +from torchy_baselines.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper class VecFrameStack(VecEnvWrapper): """ Frame stacking wrapper for vectorized environment - :param venv: (VecEnv) the vectorized environment to wrap - :param n_stack: (int) Number of frames to stack + :param venv: the vectorized environment to wrap + :param n_stack: Number of frames to stack """ - def __init__(self, venv, n_stack): + def __init__(self, venv: VecEnv, n_stack: int): self.venv = venv self.n_stack = n_stack wrapped_obs_space = venv.observation_space