Remove gym warnings

This commit is contained in:
Antonin RAFFIN 2020-03-25 15:54:58 +01:00
parent 72a88a8d92
commit 52d2cd6a1b
3 changed files with 3 additions and 2 deletions

View file

@ -15,6 +15,7 @@ filterwarnings =
# Gym warnings
ignore:Parameters to load are deprecated.:DeprecationWarning
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
ignore::UserWarning:gym
[pytype]
inputs = torchy_baselines

View file

@ -78,6 +78,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
"""
Get the dimension of the observation space.
It should not be used when using images.
:param observation_space: (spaces.Space)
:return: (Union[int, Tuple[int, ...]])

View file

@ -162,8 +162,7 @@ class PPO(BaseRLModel):
with th.no_grad():
# Convert to pytorch tensor
obs_tensor = obs.reshape((-1,) + self.observation_space.shape)
obs_tensor = th.as_tensor(obs_tensor).to(self.device)
obs_tensor = th.as_tensor(obs).to(self.device)
actions, values, log_probs = self.policy.forward(obs_tensor)
actions = actions.cpu().numpy()