mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
Improve predict method
This commit is contained in:
parent
9caea35a11
commit
f1a4fa2d3f
12 changed files with 134 additions and 90 deletions
|
|
@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.2.0a1 (WIP)
|
||||
Pre-Release 0.2.0a2 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
|
|
@ -20,7 +20,8 @@ New Features:
|
|||
- Add methods for saving and loading replay buffer
|
||||
- Add `extend()` method to the buffers
|
||||
- Add `get_vec_normalize_env()` to `BaseRLModel` to retrieve `VecNormalize` wrapper when it exists
|
||||
- Add `¶results_plotter` from Stable Baselines
|
||||
- Add `results_plotter` from Stable Baselines
|
||||
- Improve `predict()` method to handle different type of observations (single, vectorized, ...)
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -12,7 +12,9 @@ setup(name='torchy_baselines',
|
|||
'torch>=1.2.0',
|
||||
'cloudpickle',
|
||||
# For reading logs
|
||||
'pandas'
|
||||
'pandas',
|
||||
# Plotting learning curves
|
||||
'matplotlib'
|
||||
],
|
||||
extras_require={
|
||||
'tests': [
|
||||
|
|
@ -45,7 +47,7 @@ setup(name='torchy_baselines',
|
|||
license="MIT",
|
||||
long_description="",
|
||||
long_description_content_type='text/markdown',
|
||||
version="0.2.0a1",
|
||||
version="0.2.0a2",
|
||||
)
|
||||
|
||||
# python setup.py sdist
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def test_save_load(model_class):
|
|||
|
||||
env.reset()
|
||||
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
|
||||
observations = np.squeeze(observations)
|
||||
observations = observations.reshape(10, -1)
|
||||
|
||||
# Get dictionary of current parameters
|
||||
params = deepcopy(model.policy.state_dict())
|
||||
|
|
@ -53,7 +53,7 @@ def test_save_load(model_class):
|
|||
params = new_params
|
||||
|
||||
# get selected actions
|
||||
selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
selected_actions = model.predict(observations, deterministic=True)
|
||||
|
||||
# Check
|
||||
model.save("test_save.zip")
|
||||
|
|
@ -68,7 +68,7 @@ def test_save_load(model_class):
|
|||
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
|
||||
|
||||
# check if model still selects the same actions
|
||||
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
|
||||
new_selected_actions = model.predict(observations, deterministic=True)
|
||||
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
|
||||
|
||||
# check if learn still works
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ from torchy_baselines.ppo import PPO
|
|||
from torchy_baselines.sac import SAC
|
||||
from torchy_baselines.td3 import TD3
|
||||
|
||||
__version__ = "0.2.0a1"
|
||||
__version__ = "0.2.0a2"
|
||||
|
|
|
|||
|
|
@ -306,21 +306,104 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
@staticmethod
|
||||
def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
|
||||
"""
|
||||
For every observation type, detects and validates the shape,
|
||||
then returns whether or not the observation is vectorized.
|
||||
|
||||
:param observation: (np.ndarray) the input observation to validate
|
||||
:param observation_space: (gym.spaces) the observation space
|
||||
:return: (bool) whether the given observation is vectorized or not
|
||||
"""
|
||||
if isinstance(observation_space, gym.spaces.Box):
|
||||
if observation.shape == observation_space.shape:
|
||||
return False
|
||||
elif observation.shape[1:] == observation_space.shape:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Box environment, please use {} ".format(observation_space.shape) +
|
||||
"or (n_env, {}) for the observation shape."
|
||||
.format(", ".join(map(str, observation_space.shape))))
|
||||
elif isinstance(observation_space, gym.spaces.Discrete):
|
||||
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
|
||||
return False
|
||||
elif len(observation.shape) == 1:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
|
||||
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
|
||||
elif isinstance(observation_space, gym.spaces.MultiDiscrete):
|
||||
if observation.shape == (len(observation_space.nvec),):
|
||||
return False
|
||||
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
|
||||
"environment, please use ({},) or ".format(len(observation_space.nvec)) +
|
||||
"(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
|
||||
elif isinstance(observation_space, gym.spaces.MultiBinary):
|
||||
if observation.shape == (observation_space.n,):
|
||||
return False
|
||||
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
|
||||
return True
|
||||
else:
|
||||
raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
|
||||
"environment, please use ({},) or ".format(observation_space.n) +
|
||||
"(n_env, {}) for the observation shape.".format(observation_space.n))
|
||||
else:
|
||||
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
|
||||
.format(observation_space))
|
||||
|
||||
def predict(self, observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Get the model's action from an observation
|
||||
Get the model's action(s) from an observation
|
||||
|
||||
:param observation: the input observation
|
||||
:param state: The last states (can be None, used in recurrent policies)
|
||||
:param mask: The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: Whether or not to return deterministic actions.
|
||||
:return: the model's action and the next state (used in recurrent policies)
|
||||
:param observation: (np.ndarray) the input observation
|
||||
:param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
|
||||
:param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
# if state is None:
|
||||
# state = self.initial_state
|
||||
# if mask is None:
|
||||
# mask = [False for _ in range(self.n_envs)]
|
||||
observation = np.array(observation)
|
||||
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
|
||||
|
||||
observation = observation.reshape((-1,) + self.observation_space.shape)
|
||||
# Convert to float pytorch
|
||||
# TODO: replace with preprocessing
|
||||
observation = th.as_tensor(observation).float().to(self.device)
|
||||
with th.no_grad():
|
||||
actions = self.policy.predict(observation, deterministic=deterministic)
|
||||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale to proper domain when using squashing
|
||||
# TODO: should not be used for a Gaussian distribution?
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
actions = self.unscale_action(actions)
|
||||
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error when using gaussian distribution
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
if not vectorized_env:
|
||||
if state is not None:
|
||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
||||
clipped_actions = clipped_actions[0]
|
||||
|
||||
# TODO: switch to stable baselines API
|
||||
# return clipped_actions, state
|
||||
return clipped_actions
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
|
||||
|
|
@ -806,7 +889,9 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
unscaled_action = self.predict(obs, deterministic=not self.use_sde)
|
||||
# Note: we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
unscaled_action = self.predict(obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.scale_action(unscaled_action)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,30 @@
|
|||
from typing import Union
|
||||
|
||||
from itertools import zip_longest
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BasePolicy(nn.Module):
|
||||
"""
|
||||
The base policy object
|
||||
|
||||
:param observation_space: (Gym Space) The observation space of the environment
|
||||
:param action_space: (Gym Space) The action space of the environment
|
||||
:param observation_space: (gym.spaces.Space) The observation space of the environment
|
||||
:param action_space: (gym.spaces.Space) The action space of the environment
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, device='cpu'):
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'):
|
||||
super(BasePolicy, self).__init__()
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.device = device
|
||||
|
||||
@staticmethod
|
||||
def init_weights(module, gain=1):
|
||||
def init_weights(module: nn.Module, gain: float = 1):
|
||||
if type(module) == nn.Linear:
|
||||
nn.init.orthogonal_(module.weight, gain=gain)
|
||||
module.bias.data.fill_(0.0)
|
||||
|
|
@ -27,7 +32,13 @@ class BasePolicy(nn.Module):
|
|||
def forward(self, *_args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self, path):
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def save(self, path: str) -> None:
|
||||
"""
|
||||
Save model to a given location.
|
||||
|
||||
|
|
@ -35,7 +46,7 @@ class BasePolicy(nn.Module):
|
|||
"""
|
||||
th.save(self.state_dict(), path)
|
||||
|
||||
def load(self, path):
|
||||
def load(self, path: str) -> None:
|
||||
"""
|
||||
Load saved model from path.
|
||||
|
||||
|
|
@ -43,7 +54,7 @@ class BasePolicy(nn.Module):
|
|||
"""
|
||||
self.load_state_dict(th.load(path))
|
||||
|
||||
def load_from_vector(self, vector):
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
Load parameters from a 1D vector.
|
||||
|
||||
|
|
@ -51,7 +62,7 @@ class BasePolicy(nn.Module):
|
|||
"""
|
||||
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
|
||||
|
||||
def parameters_to_vector(self):
|
||||
def parameters_to_vector(self) -> np.ndarray:
|
||||
"""
|
||||
Convert the parameters to a 1D vector.
|
||||
|
||||
|
|
|
|||
|
|
@ -162,10 +162,10 @@ class PPOPolicy(BasePolicy):
|
|||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde,
|
||||
deterministic=deterministic)
|
||||
|
||||
def actor_forward(self, obs, deterministic=False):
|
||||
latent_pi, _, latent_sde = self._get_latent(obs)
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
latent_pi, _, latent_sde = self._get_latent(observation)
|
||||
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
|
||||
return action.detach().cpu().numpy()
|
||||
return action
|
||||
|
||||
def evaluate_actions(self, obs, action, deterministic=False):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -129,28 +129,6 @@ class PPO(BaseRLModel):
|
|||
if self.clip_range_vf is not None:
|
||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||
|
||||
def select_action(self, observation, deterministic=False):
|
||||
# Normally not needed
|
||||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.policy.actor_forward(observation, deterministic=deterministic)
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=False):
|
||||
"""
|
||||
Get the model's action from an observation
|
||||
|
||||
:param observation: (np.ndarray) the input observation
|
||||
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
|
||||
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
clipped_actions = self.select_action(observation, deterministic=deterministic)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high)
|
||||
return clipped_actions
|
||||
|
||||
def collect_rollouts(self,
|
||||
env: VecEnv,
|
||||
callback: BaseCallback,
|
||||
|
|
|
|||
|
|
@ -129,11 +129,11 @@ class Actor(BaseNetwork):
|
|||
def forward(self, obs, deterministic=False):
|
||||
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
|
||||
if self.use_sde:
|
||||
# Note the action is squashed
|
||||
# Note: the action is squashed
|
||||
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde,
|
||||
deterministic=deterministic)
|
||||
else:
|
||||
# Note the action is squashed
|
||||
# Note: the action is squashed
|
||||
action, _ = self.action_dist.proba_distribution(mean_actions, log_std,
|
||||
deterministic=deterministic)
|
||||
return action
|
||||
|
|
@ -246,6 +246,8 @@ class SACPolicy(BasePolicy):
|
|||
def forward(self, obs):
|
||||
return self.actor(obs)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor.forward(observation, deterministic)
|
||||
|
||||
MlpPolicy = SACPolicy
|
||||
|
||||
|
|
|
|||
|
|
@ -148,25 +148,6 @@ class SAC(OffPolicyRLModel):
|
|||
self.critic = self.policy.critic
|
||||
self.critic_target = self.policy.critic_target
|
||||
|
||||
def select_action(self, observation):
|
||||
# Normally not needed
|
||||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.actor(observation).cpu().data.numpy()
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=True):
|
||||
"""
|
||||
Get the model's action from an observation
|
||||
|
||||
:param observation: (np.ndarray) the input observation
|
||||
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
|
||||
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
return self.unscale_action(self.select_action(observation))
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 64):
|
||||
# Update optimizers learning rate
|
||||
optimizers = [self.actor.optimizer, self.critic.optimizer]
|
||||
|
|
|
|||
|
|
@ -277,6 +277,9 @@ class TD3Policy(BasePolicy):
|
|||
def forward(self, obs, deterministic=True):
|
||||
return self.actor(obs, deterministic=deterministic)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.forward(observation, deterministic)
|
||||
|
||||
|
||||
MlpPolicy = TD3Policy
|
||||
|
||||
|
|
|
|||
|
|
@ -114,25 +114,6 @@ class TD3(OffPolicyRLModel):
|
|||
self.critic_target = self.policy.critic_target
|
||||
self.vf_net = self.policy.vf_net
|
||||
|
||||
def select_action(self, observation, deterministic=True):
|
||||
# Normally not needed
|
||||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.actor(observation, deterministic=deterministic).cpu().numpy()
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=True):
|
||||
"""
|
||||
Get the model's action from an observation
|
||||
|
||||
:param observation: (np.ndarray) the input observation
|
||||
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
|
||||
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
|
||||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
|
||||
|
||||
def train_critic(self, gradient_steps: int = 1,
|
||||
batch_size: int = 100,
|
||||
replay_data: Optional[ReplayBufferSamples] = None,
|
||||
|
|
|
|||
Loading…
Reference in a new issue