mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Remove gym warnings
This commit is contained in:
parent
72a88a8d92
commit
52d2cd6a1b
3 changed files with 3 additions and 2 deletions
|
|
@ -15,6 +15,7 @@ filterwarnings =
|
|||
# Gym warnings
|
||||
ignore:Parameters to load are deprecated.:DeprecationWarning
|
||||
ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning
|
||||
ignore::UserWarning:gym
|
||||
|
||||
[pytype]
|
||||
inputs = torchy_baselines
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
|
|||
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
|
||||
"""
|
||||
Get the dimension of the observation space.
|
||||
It should not be used when using images.
|
||||
|
||||
:param observation_space: (spaces.Space)
|
||||
:return: (Union[int, Tuple[int, ...]])
|
||||
|
|
|
|||
|
|
@ -162,8 +162,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
with th.no_grad():
|
||||
# Convert to pytorch tensor
|
||||
obs_tensor = obs.reshape((-1,) + self.observation_space.shape)
|
||||
obs_tensor = th.as_tensor(obs_tensor).to(self.device)
|
||||
obs_tensor = th.as_tensor(obs).to(self.device)
|
||||
actions, values, log_probs = self.policy.forward(obs_tensor)
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue