From 57b37513b64a593c7639f2ee4df759551e5996eb Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 20 Mar 2020 10:09:09 +0100 Subject: [PATCH] Refactor handling of obs and action space + remove duplicated code --- docs/misc/changelog.rst | 23 +++++++++++++ torchy_baselines/common/base_class.py | 38 +++++++++++++++++--- torchy_baselines/common/buffers.py | 39 ++++++++++++--------- torchy_baselines/common/distributions.py | 12 ++++--- torchy_baselines/common/preprocessing.py | 44 ++++++++++++++++++++++++ torchy_baselines/ppo/policies.py | 3 +- torchy_baselines/ppo/ppo.py | 18 +++------- torchy_baselines/sac/policies.py | 5 +-- torchy_baselines/sac/sac.py | 22 +++--------- torchy_baselines/td3/policies.py | 5 +-- torchy_baselines/td3/td3.py | 17 +++------ torchy_baselines/version.txt | 2 +- 12 files changed, 152 insertions(+), 76 deletions(-) create mode 100644 torchy_baselines/common/preprocessing.py diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7ad06b9..d58ba20 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,29 @@ Changelog ========== +Pre-Release 0.4.0a0 (WIP) +------------------------------ + + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Refactor handling of observation and action spaces + +Documentation: +^^^^^^^^^^^^^^ + Pre-Release 0.3.0 (2020-02-14) ------------------------------ diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index cd9eb77..54024a3 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -31,6 +31,8 @@ class BaseRLModel(ABC): :param env: (Union[GymEnv, str]) The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param policy_base: (Type[BasePolicy]) The base policy used by this method + :param learning_rate: (float or callable) learning rate for the optimizer, + it can be a function of the current progress (from 1 to 0) :param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation :param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug :param device: (Union[th.device, str]) Device on which the code should run. @@ -53,6 +55,7 @@ class BaseRLModel(ABC): policy: Type[BasePolicy], env: Union[GymEnv, str], policy_base: Type[BasePolicy], + learning_rate: Union[float, Callable], policy_kwargs: Dict[str, Any] = None, verbose: int = 0, device: Union[th.device, str] = 'auto', @@ -89,7 +92,7 @@ class BaseRLModel(ABC): self.action_noise = None # type: Optional[ActionNoise] self.start_time = None self.policy = None - self.learning_rate = None # type: Optional[float] + self.learning_rate = learning_rate self.lr_schedule = None # type: Optional[Callable] # Used for SDE only self.use_sde = use_sde @@ -135,7 +138,7 @@ class BaseRLModel(ABC): @abstractmethod def _setup_model(self) -> None: """ - Create networks and optimizers + Create networks, buffer and optimizers """ raise NotImplementedError() @@ -761,6 +764,11 @@ class OffPolicyRLModel(BaseRLModel): :param env: The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param policy_base: The base policy used by this method + :param learning_rate: (float or callable) learning rate for the optimizer, + it can be a function of the current progress (from 1 to 0) + :param buffer_size: (int) size of the replay buffer + :param learning_starts: (int) how many steps of the model to collect transitions for before learning starts + :param batch_size: (int) Minibatch size for each gradient update :param policy_kwargs: Additional arguments to be passed to the policy on creation :param verbose: The verbosity level: 0 none, 1 training information, 2 debug :param device: Device on which the code should run. @@ -785,6 +793,10 @@ class OffPolicyRLModel(BaseRLModel): policy: Type[BasePolicy], env: Union[GymEnv, str], policy_base: Type[BasePolicy], + learning_rate: Union[float, Callable], + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 256, policy_kwargs: Dict[str, Any] = None, verbose: int = 0, device: Union[th.device, str] = 'auto', @@ -796,16 +808,32 @@ class OffPolicyRLModel(BaseRLModel): sde_sample_freq: int = -1, use_sde_at_warmup: bool = False): - super(OffPolicyRLModel, self).__init__(policy, env, policy_base, policy_kwargs, verbose, + super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate, + policy_kwargs, verbose, device, support_multi_env, create_eval_env, monitor_wrapper, seed, use_sde, sde_sample_freq) + self.buffer_size = buffer_size + self.batch_size = batch_size + self.learning_starts = learning_starts + self.actor = None + self.replay_buffer = None # type: Optional[ReplayBuffer] + # Update policy keyword arguments + self.policy_kwargs['use_sde'] = self.use_sde + self.policy_kwargs['device'] = self.device # For SDE only self.rollout_data = None self.on_policy_exploration = False - self.actor = None - self.replay_buffer = None # type: Optional[ReplayBuffer] self.use_sde_at_warmup = use_sde_at_warmup + def _setup_model(self): + self._setup_lr_schedule() + self.set_random_seed(self.seed) + self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space, + self.action_space, self.device) + self.policy = self.policy_class(self.observation_space, self.action_space, + self.lr_schedule, **self.policy_kwargs) + self.policy = self.policy.to(self.device) + def save_replay_buffer(self, path: str): """ Save the replay buffer as a pickle file. diff --git a/torchy_baselines/common/buffers.py b/torchy_baselines/common/buffers.py index 5ddc29c..0932ce6 100644 --- a/torchy_baselines/common/buffers.py +++ b/torchy_baselines/common/buffers.py @@ -2,9 +2,11 @@ from typing import Union, Optional, Generator import numpy as np import torch as th +from gym import spaces from torchy_baselines.common.vec_env import VecNormalize from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples +from torchy_baselines.common.preprocessing import get_obs_dim, get_action_dim class BaseBuffer(object): @@ -12,23 +14,24 @@ class BaseBuffer(object): Base class that represent a buffer (rollout or replay) :param buffer_size: (int) Max number of element in the buffer - :param obs_dim: (int) Dimension of the observation - :param action_dim: (int) Dimension of the action space + :param observation_space: (spaces.Space) Observation space + :param action_space: (spaces.Space) Action space :param device: (Union[th.device, str]) PyTorch device to which the values will be converted :param n_envs: (int) Number of parallel environments """ - def __init__(self, buffer_size: int, - obs_dim: int, - action_dim: int, + observation_space: spaces.Space, + action_space: spaces.Space, device: Union[th.device, str] = 'cpu', n_envs: int = 1): super(BaseBuffer, self).__init__() self.buffer_size = buffer_size - self.obs_dim = obs_dim - self.action_dim = action_dim + self.observation_space = observation_space + self.action_space = action_space + self.obs_dim = get_obs_dim(observation_space) + self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False self.device = device @@ -137,19 +140,20 @@ class ReplayBuffer(BaseBuffer): Replay buffer used in off-policy algorithms like SAC/TD3. :param buffer_size: (int) Max number of element in the buffer - :param obs_dim: (int) Dimension of the observation - :param action_dim: (int) Dimension of the action space + :param observation_space: (spaces.Space) Observation space + :param action_space: (spaces.Space) Action space :param device: (th.device) :param n_envs: (int) Number of parallel environments """ def __init__(self, buffer_size: int, - obs_dim: int, - action_dim: int, + observation_space: spaces.Space, + action_space: spaces.Space, device: Union[th.device, str] = 'cpu', n_envs: int = 1): - super(ReplayBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs) + super(ReplayBuffer, self).__init__(buffer_size, observation_space, + action_space, device, n_envs=n_envs) assert n_envs == 1, "Replay buffer only support single environment for now" @@ -194,8 +198,8 @@ class RolloutBuffer(BaseBuffer): Rollout buffer used in on-policy algorithms like A2C/PPO. :param buffer_size: (int) Max number of element in the buffer - :param obs_dim: (int) Dimension of the observation - :param action_dim: (int) Dimension of the action space + :param observation_space: (spaces.Space) Observation space + :param action_space: (spaces.Space) Action space :param device: (th.device) :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. @@ -205,14 +209,15 @@ class RolloutBuffer(BaseBuffer): def __init__(self, buffer_size: int, - obs_dim: int, - action_dim: int, + observation_space: spaces.Space, + action_space: spaces.Space, device: Union[th.device, str] = 'cpu', gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1): - super(RolloutBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs) + super(RolloutBuffer, self).__init__(buffer_size, observation_space, + action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma self.observations, self.actions, self.rewards, self.advantages = None, None, None, None diff --git a/torchy_baselines/common/distributions.py b/torchy_baselines/common/distributions.py index 5d04f83..baf54aa 100644 --- a/torchy_baselines/common/distributions.py +++ b/torchy_baselines/common/distributions.py @@ -6,6 +6,8 @@ import torch.nn as nn from torch.distributions import Normal, Categorical from gym import spaces +from torchy_baselines.common.preprocessing import get_action_dim + class Distribution(object): def __init__(self): @@ -59,7 +61,7 @@ class DiagGaussianDistribution(Distribution): Gaussian distribution with diagonal covariance matrix, for continuous actions. - :param action_dim: (int) Number of continuous actions + :param action_dim: (int) Dimension of the action space. """ def __init__(self, action_dim: int): @@ -144,7 +146,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. - :param action_dim: (int) Number of continuous actions + :param action_dim: (int) Dimension of the action space. :param epsilon: (float) small value to avoid NaN due to numerical imprecision. """ @@ -252,7 +254,7 @@ class StateDependentNoiseDistribution(Distribution): It is used to create the noise exploration matrix and compute the log probabilty of an action with that noise. - :param action_dim: (int) Number of continuous actions + :param action_dim: (int) Dimension of the action space. :param full_std: (bool) Whether to use (n_features x n_actions) parameters for the std instead of only (n_features,) :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure @@ -495,8 +497,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" if use_sde: - return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs) - return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs) + return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) + return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) # elif isinstance(action_space, spaces.MultiDiscrete): diff --git a/torchy_baselines/common/preprocessing.py b/torchy_baselines/common/preprocessing.py new file mode 100644 index 0000000..024deca --- /dev/null +++ b/torchy_baselines/common/preprocessing.py @@ -0,0 +1,44 @@ +from typing import Tuple, Union + +import numpy as np +import torch as th +from gym import spaces + + +def is_image(observation_space): + return False + + +def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space) -> th.Tensor: + if isinstance(observation_space, spaces.Box): + if is_image(observation_space): + return obs / 255.0 + return obs + elif isinstance(observation_space, spaces.Discrete): + # TODO: one hot encoding + raise NotImplementedError() + else: + # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict + raise NotImplementedError() + + +def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: + if isinstance(observation_space, spaces.Box): + if is_image(observation_space): + return observation_space.shape + return np.prod(observation_space.shape) + elif isinstance(observation_space, spaces.Discrete): + return 1 + else: + # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict + raise NotImplementedError() + + +def get_action_dim(action_space: spaces.Space) -> int: + if isinstance(action_space, spaces.Box): + return int(np.prod(action_space.shape)) + elif isinstance(action_space, spaces.Discrete): + # Action is an int + return 1 + else: + raise NotImplementedError() diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index d54e7dd..163e890 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -6,6 +6,7 @@ import torch as th import torch.nn as nn import numpy as np +from torchy_baselines.common.preprocessing import get_obs_dim from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor, create_sde_feature_extractor) from torchy_baselines.common.distributions import (make_proba_distribution, Distribution, @@ -54,7 +55,7 @@ class PPOPolicy(BasePolicy): use_expln: bool = False, squash_output: bool = False): super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output) - self.obs_dim = self.observation_space.shape[0] + self.obs_dim = get_obs_dim(self.observation_space) # Default network architecture, from stable-baselines if net_arch is None: diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 098aa3c..50efaf5 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -98,11 +98,10 @@ class PPO(BaseRLModel): device: Union[th.device, str] = 'auto', _init_setup_model: bool = True): - super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs, + super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate, policy_kwargs=policy_kwargs, verbose=verbose, device=device, use_sde=use_sde, sde_sample_freq=sde_sample_freq, create_eval_env=create_eval_env, support_multi_env=True, seed=seed) - self.learning_rate = learning_rate self.batch_size = batch_size self.n_epochs = n_epochs self.n_steps = n_steps @@ -123,19 +122,12 @@ class PPO(BaseRLModel): def _setup_model(self) -> None: self._setup_lr_schedule() - # TODO: preprocessing: one hot vector for obs discrete - state_dim = self.observation_space.shape[0] - if isinstance(self.action_space, spaces.Box): - # Action is a 1D vector - action_dim = self.action_space.shape[0] - elif isinstance(self.action_space, spaces.Discrete): - # Action is a scalar - action_dim = 1 - self.set_random_seed(self.seed) - self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, - gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) + self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space, + self.action_space, self.device, + gamma=self.gamma, gae_lambda=self.gae_lambda, + n_envs=self.n_envs) self.policy = self.policy_class(self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, device=self.device, **self.policy_kwargs) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index b88f613..2b7367a 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -4,6 +4,7 @@ import gym import torch as th import torch.nn as nn +from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor) from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -191,8 +192,8 @@ class SACPolicy(BasePolicy): if net_arch is None: net_arch = [256, 256] - self.obs_dim = self.observation_space.shape[0] - self.action_dim = self.action_space.shape[0] + self.obs_dim = get_obs_dim(self.observation_space) + self.action_dim = get_action_dim(self.action_space) self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 8878c3f..496fe0d 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -89,7 +89,9 @@ class SAC(OffPolicyRLModel): device: Union[th.device, str] = 'auto', _init_setup_model: bool = True): - super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device, + super(SAC, self).__init__(policy, env, SACPolicy, learning_rate, + buffer_size, learning_starts, batch_size, + policy_kwargs, verbose, device, create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup) @@ -97,11 +99,6 @@ class SAC(OffPolicyRLModel): self.target_entropy = target_entropy self.log_ent_coef = None # type: Optional[th.Tensor] self.target_update_interval = target_update_interval - self.buffer_size = buffer_size - # In the original paper, same learning rate is used for all networks - self.learning_rate = learning_rate - self.learning_starts = learning_starts - self.batch_size = batch_size self.tau = tau # Entropy coefficient / Entropy temperature # Inverse of the reward scale @@ -118,10 +115,8 @@ class SAC(OffPolicyRLModel): self._setup_model() def _setup_model(self) -> None: - self._setup_lr_schedule() - obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - if self.seed is not None: - self.set_random_seed(self.seed) + super(SAC, self)._setup_model() + self._create_aliases() # Target entropy is used when learning the entropy coefficient if self.target_entropy == 'auto': @@ -152,13 +147,6 @@ class SAC(OffPolicyRLModel): # is passed self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device) - self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, use_sde=self.use_sde, - device=self.device, **self.policy_kwargs) - self.policy = self.policy.to(self.device) - self._create_aliases() - def _create_aliases(self) -> None: self.actor = self.policy.actor self.critic = self.policy.critic diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 4daa02d..b48e98a 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -4,6 +4,7 @@ import gym import torch as th import torch.nn as nn +from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork, create_sde_feature_extractor) from torchy_baselines.common.distributions import StateDependentNoiseDistribution @@ -230,8 +231,8 @@ class TD3Policy(BasePolicy): if net_arch is None: net_arch = [400, 300] - self.obs_dim = self.observation_space.shape[0] - self.action_dim = self.action_space.shape[0] + self.obs_dim = get_obs_dim(self.observation_space) + self.action_dim = get_action_dim(self.action_space) self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 70f3a5f..79d3d88 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -88,18 +88,16 @@ class TD3(OffPolicyRLModel): device: Union[th.device, str] = 'auto', _init_setup_model: bool = True): - super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device, + super(TD3, self).__init__(policy, env, TD3Policy, learning_rate, + buffer_size, learning_starts, batch_size, + policy_kwargs, verbose, device, create_eval_env=create_eval_env, seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup) - self.buffer_size = buffer_size - self.learning_rate = learning_rate - self.learning_starts = learning_starts self.train_freq = train_freq self.gradient_steps = gradient_steps self.n_episodes_rollout = n_episodes_rollout - self.batch_size = batch_size self.tau = tau self.gamma = gamma self.action_noise = action_noise @@ -118,14 +116,7 @@ class TD3(OffPolicyRLModel): self._setup_model() def _setup_model(self) -> None: - self._setup_lr_schedule() - obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.set_random_seed(self.seed) - self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) - self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, use_sde=self.use_sde, - device=self.device, **self.policy_kwargs) - self.policy = self.policy.to(self.device) + super(TD3, self)._setup_model() self._create_aliases() def _create_aliases(self) -> None: diff --git a/torchy_baselines/version.txt b/torchy_baselines/version.txt index 0d91a54..f28aaa5 100644 --- a/torchy_baselines/version.txt +++ b/torchy_baselines/version.txt @@ -1 +1 @@ -0.3.0 +0.4.0a0