diff --git a/setup.cfg b/setup.cfg index 708de33..c4d1055 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,6 +15,7 @@ filterwarnings = # Gym warnings ignore:Parameters to load are deprecated.:DeprecationWarning ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning + ignore::UserWarning:gym [pytype] inputs = torchy_baselines diff --git a/torchy_baselines/common/preprocessing.py b/torchy_baselines/common/preprocessing.py index e117287..fa312bf 100644 --- a/torchy_baselines/common/preprocessing.py +++ b/torchy_baselines/common/preprocessing.py @@ -78,6 +78,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: """ Get the dimension of the observation space. + It should not be used when using images. :param observation_space: (spaces.Space) :return: (Union[int, Tuple[int, ...]]) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 0aea585..51473af 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -162,8 +162,7 @@ class PPO(BaseRLModel): with th.no_grad(): # Convert to pytorch tensor - obs_tensor = obs.reshape((-1,) + self.observation_space.shape) - obs_tensor = th.as_tensor(obs_tensor).to(self.device) + obs_tensor = th.as_tensor(obs).to(self.device) actions, values, log_probs = self.policy.forward(obs_tensor) actions = actions.cpu().numpy()