From f1a4fa2d3fae520e1308929d04f78e6d7b6223cb Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 12 Feb 2020 15:25:05 +0100 Subject: [PATCH] Improve predict method --- docs/misc/changelog.rst | 5 +- setup.py | 6 +- tests/test_save_load.py | 6 +- torchy_baselines/__init__.py | 2 +- torchy_baselines/common/base_class.py | 103 +++++++++++++++++++++++--- torchy_baselines/common/policies.py | 27 +++++-- torchy_baselines/ppo/policies.py | 6 +- torchy_baselines/ppo/ppo.py | 22 ------ torchy_baselines/sac/policies.py | 6 +- torchy_baselines/sac/sac.py | 19 ----- torchy_baselines/td3/policies.py | 3 + torchy_baselines/td3/td3.py | 19 ----- 12 files changed, 134 insertions(+), 90 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b083e39..02bf897 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.2.0a1 (WIP) +Pre-Release 0.2.0a2 (WIP) ------------------------------ Breaking Changes: @@ -20,7 +20,8 @@ New Features: - Add methods for saving and loading replay buffer - Add `extend()` method to the buffers - Add `get_vec_normalize_env()` to `BaseRLModel` to retrieve `VecNormalize` wrapper when it exists -- Add `¶results_plotter` from Stable Baselines +- Add `results_plotter` from Stable Baselines +- Improve `predict()` method to handle different type of observations (single, vectorized, ...) Bug Fixes: ^^^^^^^^^^ diff --git a/setup.py b/setup.py index b9598fc..92389ea 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,9 @@ setup(name='torchy_baselines', 'torch>=1.2.0', 'cloudpickle', # For reading logs - 'pandas' + 'pandas', + # Plotting learning curves + 'matplotlib' ], extras_require={ 'tests': [ @@ -45,7 +47,7 @@ setup(name='torchy_baselines', license="MIT", long_description="", long_description_content_type='text/markdown', - version="0.2.0a1", + version="0.2.0a2", ) # python setup.py sdist diff --git a/tests/test_save_load.py b/tests/test_save_load.py index edd326a..45c0ac5 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -34,7 +34,7 @@ def test_save_load(model_class): env.reset() observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) - observations = np.squeeze(observations) + observations = observations.reshape(10, -1) # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) @@ -53,7 +53,7 @@ def test_save_load(model_class): params = new_params # get selected actions - selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + selected_actions = model.predict(observations, deterministic=True) # Check model.save("test_save.zip") @@ -68,7 +68,7 @@ def test_save_load(model_class): assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." # check if model still selects the same actions - new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations] + new_selected_actions = model.predict(observations, deterministic=True) assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works diff --git a/torchy_baselines/__init__.py b/torchy_baselines/__init__.py index e250cc2..5e22a8d 100644 --- a/torchy_baselines/__init__.py +++ b/torchy_baselines/__init__.py @@ -4,4 +4,4 @@ from torchy_baselines.ppo import PPO from torchy_baselines.sac import SAC from torchy_baselines.td3 import TD3 -__version__ = "0.2.0a1" +__version__ = "0.2.0a2" diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index d4e6810..2348d0e 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -306,21 +306,104 @@ class BaseRLModel(ABC): """ raise NotImplementedError() - @abstractmethod + @staticmethod + def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool: + """ + For every observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: (np.ndarray) the input observation to validate + :param observation_space: (gym.spaces) the observation space + :return: (bool) whether the given observation is vectorized or not + """ + if isinstance(observation_space, gym.spaces.Box): + if observation.shape == observation_space.shape: + return False + elif observation.shape[1:] == observation_space.shape: + return True + else: + raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + "Box environment, please use {} ".format(observation_space.shape) + + "or (n_env, {}) for the observation shape." + .format(", ".join(map(str, observation_space.shape)))) + elif isinstance(observation_space, gym.spaces.Discrete): + if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' + return False + elif len(observation.shape) == 1: + return True + else: + raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + elif isinstance(observation_space, gym.spaces.MultiDiscrete): + if observation.shape == (len(observation_space.nvec),): + return False + elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): + return True + else: + raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) + + "environment, please use ({},) or ".format(len(observation_space.nvec)) + + "(n_env, {}) for the observation shape.".format(len(observation_space.nvec))) + elif isinstance(observation_space, gym.spaces.MultiBinary): + if observation.shape == (observation_space.n,): + return False + elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + return True + else: + raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) + + "environment, please use ({},) or ".format(observation_space.n) + + "(n_env, {}) for the observation shape.".format(observation_space.n)) + else: + raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." + .format(observation_space)) + def predict(self, observation: np.ndarray, state: Optional[np.ndarray] = None, mask: Optional[np.ndarray] = None, deterministic: bool = False) -> np.ndarray: """ - Get the model's action from an observation + Get the model's action(s) from an observation - :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) - :param deterministic: Whether or not to return deterministic actions. - :return: the model's action and the next state (used in recurrent policies) + :param observation: (np.ndarray) the input observation + :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies) + :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies) + :param deterministic: (bool) Whether or not to return deterministic actions. + :return: (np.ndarray) the model's action and the next state (used in recurrent policies) """ - raise NotImplementedError() + # if state is None: + # state = self.initial_state + # if mask is None: + # mask = [False for _ in range(self.n_envs)] + observation = np.array(observation) + vectorized_env = self._is_vectorized_observation(observation, self.observation_space) + + observation = observation.reshape((-1,) + self.observation_space.shape) + # Convert to float pytorch + # TODO: replace with preprocessing + observation = th.as_tensor(observation).float().to(self.device) + with th.no_grad(): + actions = self.policy.predict(observation, deterministic=deterministic) + # Convert to numpy + actions = actions.cpu().numpy() + + # Rescale to proper domain when using squashing + # TODO: should not be used for a Gaussian distribution? + if isinstance(self.action_space, gym.spaces.Box): + actions = self.unscale_action(actions) + + clipped_actions = actions + # Clip the actions to avoid out of bound error when using gaussian distribution + if isinstance(self.action_space, gym.spaces.Box): + clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + + if not vectorized_env: + if state is not None: + raise ValueError("Error: The environment must be vectorized when using recurrent policies.") + clipped_actions = clipped_actions[0] + + # TODO: switch to stable baselines API + # return clipped_actions, state + return clipped_actions + @classmethod def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs): @@ -806,7 +889,9 @@ class OffPolicyRLModel(BaseRLModel): # Warmup phase unscaled_action = np.array([self.action_space.sample()]) else: - unscaled_action = self.predict(obs, deterministic=not self.use_sde) + # Note: we assume that the policy uses tanh to scale the action + # We use non-deterministic action in the case of SAC, for TD3, it does not matter + unscaled_action = self.predict(obs, deterministic=False) # Rescale the action from [low, high] to [-1, 1] scaled_action = self.scale_action(unscaled_action) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index ea7f755..1e0064d 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -1,25 +1,30 @@ +from typing import Union + from itertools import zip_longest +import gym import torch as th import torch.nn as nn +import numpy as np class BasePolicy(nn.Module): """ The base policy object - :param observation_space: (Gym Space) The observation space of the environment - :param action_space: (Gym Space) The action space of the environment + :param observation_space: (gym.spaces.Space) The observation space of the environment + :param action_space: (gym.spaces.Space) The action space of the environment """ - def __init__(self, observation_space, action_space, device='cpu'): + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'): super(BasePolicy, self).__init__() self.observation_space = observation_space self.action_space = action_space self.device = device @staticmethod - def init_weights(module, gain=1): + def init_weights(module: nn.Module, gain: float = 1): if type(module) == nn.Linear: nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) @@ -27,7 +32,13 @@ class BasePolicy(nn.Module): def forward(self, *_args, **kwargs): raise NotImplementedError() - def save(self, path): + def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + """ + raise NotImplementedError() + + def save(self, path: str) -> None: """ Save model to a given location. @@ -35,7 +46,7 @@ class BasePolicy(nn.Module): """ th.save(self.state_dict(), path) - def load(self, path): + def load(self, path: str) -> None: """ Load saved model from path. @@ -43,7 +54,7 @@ class BasePolicy(nn.Module): """ self.load_state_dict(th.load(path)) - def load_from_vector(self, vector): + def load_from_vector(self, vector: np.ndarray): """ Load parameters from a 1D vector. @@ -51,7 +62,7 @@ class BasePolicy(nn.Module): """ th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) - def parameters_to_vector(self): + def parameters_to_vector(self) -> np.ndarray: """ Convert the parameters to a 1D vector. diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 4421d50..1e34926 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -162,10 +162,10 @@ class PPOPolicy(BasePolicy): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde, deterministic=deterministic) - def actor_forward(self, obs, deterministic=False): - latent_pi, _, latent_sde = self._get_latent(obs) + def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + latent_pi, _, latent_sde = self._get_latent(observation) action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic) - return action.detach().cpu().numpy() + return action def evaluate_actions(self, obs, action, deterministic=False): """ diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index b17f4f8..2ce168b 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -129,28 +129,6 @@ class PPO(BaseRLModel): if self.clip_range_vf is not None: self.clip_range_vf = get_schedule_fn(self.clip_range_vf) - def select_action(self, observation, deterministic=False): - # Normally not needed - observation = np.array(observation) - with th.no_grad(): - observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.policy.actor_forward(observation, deterministic=deterministic) - - def predict(self, observation, state=None, mask=None, deterministic=False): - """ - Get the model's action from an observation - - :param observation: (np.ndarray) the input observation - :param state: (np.ndarray) The last states (can be None, used in recurrent policies) - :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) - """ - clipped_actions = self.select_action(observation, deterministic=deterministic) - if isinstance(self.action_space, gym.spaces.Box): - clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high) - return clipped_actions - def collect_rollouts(self, env: VecEnv, callback: BaseCallback, diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 3fe11c5..3fbea93 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -129,11 +129,11 @@ class Actor(BaseNetwork): def forward(self, obs, deterministic=False): mean_actions, log_std, latent_sde = self.get_action_dist_params(obs) if self.use_sde: - # Note the action is squashed + # Note: the action is squashed action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde, deterministic=deterministic) else: - # Note the action is squashed + # Note: the action is squashed action, _ = self.action_dist.proba_distribution(mean_actions, log_std, deterministic=deterministic) return action @@ -246,6 +246,8 @@ class SACPolicy(BasePolicy): def forward(self, obs): return self.actor(obs) + def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.actor.forward(observation, deterministic) MlpPolicy = SACPolicy diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 9617ec7..f0930ee 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -148,25 +148,6 @@ class SAC(OffPolicyRLModel): self.critic = self.policy.critic self.critic_target = self.policy.critic_target - def select_action(self, observation): - # Normally not needed - observation = np.array(observation) - with th.no_grad(): - observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.actor(observation).cpu().data.numpy() - - def predict(self, observation, state=None, mask=None, deterministic=True): - """ - Get the model's action from an observation - - :param observation: (np.ndarray) the input observation - :param state: (np.ndarray) The last states (can be None, used in recurrent policies) - :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) - """ - return self.unscale_action(self.select_action(observation)) - def train(self, gradient_steps: int, batch_size: int = 64): # Update optimizers learning rate optimizers = [self.actor.optimizer, self.critic.optimizer] diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 8bc5f60..fa19952 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -277,6 +277,9 @@ class TD3Policy(BasePolicy): def forward(self, obs, deterministic=True): return self.actor(obs, deterministic=deterministic) + def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + return self.forward(observation, deterministic) + MlpPolicy = TD3Policy diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 05f1b4f..1ba5947 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -114,25 +114,6 @@ class TD3(OffPolicyRLModel): self.critic_target = self.policy.critic_target self.vf_net = self.policy.vf_net - def select_action(self, observation, deterministic=True): - # Normally not needed - observation = np.array(observation) - with th.no_grad(): - observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.actor(observation, deterministic=deterministic).cpu().numpy() - - def predict(self, observation, state=None, mask=None, deterministic=True): - """ - Get the model's action from an observation - - :param observation: (np.ndarray) the input observation - :param state: (np.ndarray) The last states (can be None, used in recurrent policies) - :param mask: (np.ndarray) The last masks (can be None, used in recurrent policies) - :param deterministic: (bool) Whether or not to return deterministic actions. - :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) - """ - return self.unscale_action(self.select_action(observation, deterministic=deterministic)) - def train_critic(self, gradient_steps: int = 1, batch_size: int = 100, replay_data: Optional[ReplayBufferSamples] = None,