From aa1026ee879455ef6952140c742fb9f16fbdb994 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 17 Apr 2020 15:13:45 +0200 Subject: [PATCH] Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` --- docs/misc/changelog.rst | 48 ++++++++++++++++------------- tests/test_custom_policy.py | 24 +++++++++++++-- torchy_baselines/a2c/a2c.py | 14 +++------ torchy_baselines/common/policies.py | 1 + torchy_baselines/ppo/policies.py | 24 +++++++++++---- torchy_baselines/sac/policies.py | 22 ++++++++++--- torchy_baselines/td3/policies.py | 24 +++++++++++---- 7 files changed, 107 insertions(+), 50 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 56cf18f..0d1620d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,32 @@ Changelog ========== +Pre-Release 0.5.0a0 (WIP) +------------------------------ + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ +- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily + customize optimizers + +Bug Fixes: +^^^^^^^^^^ +- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True`` + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Cleanup rollout return + +Documentation: +^^^^^^^^^^^^^^ + + Pre-Release 0.4.0 (2020-02-14) ------------------------------ @@ -32,28 +58,6 @@ Others: - Refactored action distributions -Documentation: -^^^^^^^^^^^^^^ - -Pre-Release 0.5.0a0 (WIP) ------------------------------- - -Breaking Changes: -^^^^^^^^^^^^^^^^^ - -New Features: -^^^^^^^^^^^^^ - -Bug Fixes: -^^^^^^^^^^ -- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True`` - -Others: -^^^^^^^ -- Cleanup rollout return - - - Pre-Release 0.3.0 (2020-02-14) ------------------------------ diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 2555016..d2ffab0 100644 --- a/tests/test_custom_policy.py +++ 
b/tests/test_custom_policy.py @@ -1,6 +1,7 @@ import pytest +import torch as th -from torchy_baselines import PPO +from torchy_baselines import A2C, PPO, SAC, TD3 @pytest.mark.parametrize('net_arch', [ @@ -11,5 +12,22 @@ from torchy_baselines import PPO [12, dict(vf=[8], pi=[8, 4])], [12, dict(pi=[8])], ]) -def test_flexible_mlp(net_arch): - _ = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) +@pytest.mark.parametrize('model_class', [A2C, PPO]) +def test_flexible_mlp(model_class, net_arch): + _ = model_class('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) + + +@pytest.mark.parametrize('net_arch', [ + [4], + [4, 4], +]) +@pytest.mark.parametrize('model_class', [SAC, TD3]) +def test_custom_offpolicy(model_class, net_arch): + _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=net_arch)).learn(1000) + + +@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3]) +@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)]) +def test_custom_optimizer(model_class, optimizer_kwargs): + policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) + _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000) diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 634a941..78e9038 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -81,19 +81,15 @@ class A2C(PPO): seed=seed, _init_setup_model=False) self.normalize_advantage = normalize_advantage - self.rms_prop_eps = rms_prop_eps - self.use_rms_prop = use_rms_prop + # Override PPO optimizer to match original implementation + if use_rms_prop and 'optimizer' not in self.policy_kwargs: + self.policy_kwargs['optimizer'] = th.optim.RMSprop + self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps, + weight_decay=0) if _init_setup_model: self._setup_model() - def 
_setup_model(self) -> None: - super(A2C, self)._setup_model() - if self.use_rms_prop: - self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(), - lr=self.lr_schedule(1), alpha=0.99, - eps=self.rms_prop_eps, weight_decay=0) - def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None: # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 58661f1..e881139 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -37,6 +37,7 @@ class BasePolicy(nn.Module): self.features_extractor = features_extractor self.normalize_images = normalize_images self._squash_output = squash_output + self.optimizer = None # type: Optional[th.optim.Optimizer] def extract_features(self, obs: th.Tensor) -> th.Tensor: """ diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 5ea14a7..0dddf51 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple, Callable, Union, Dict, Type +from typing import Optional, List, Tuple, Callable, Union, Dict, Type, Any from functools import partial import gym @@ -24,7 +24,6 @@ class PPOPolicy(BasePolicy): :param net_arch: ([int or dict]) The specification of the policy and value networks. :param device: (str or th.device) Device on which the code should run. :param activation_fn: (Type[nn.Module]) Activation function - :param adam_epsilon: (float) Small values to avoid NaN in ADAM optimizer :param ortho_init: (bool) Whether to use or not orthogonal initialization :param use_sde: (bool) Whether to use State Dependent Exploration or not :param log_std_init: (float) Initial value for the log standard deviation @@ -40,6 +39,10 @@ class PPOPolicy(BasePolicy): this allows to ensure boundaries when using SDE. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) + :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer """ def __init__(self, observation_space: gym.spaces.Space, @@ -48,7 +51,6 @@ class PPOPolicy(BasePolicy): net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, device: Union[th.device, str] = 'cpu', activation_fn: Type[nn.Module] = nn.Tanh, - adam_epsilon: float = 1e-5, ortho_init: bool = True, use_sde: bool = False, log_std_init: float = 0.0, @@ -56,7 +58,9 @@ class PPOPolicy(BasePolicy): sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, squash_output: bool = False, - normalize_images: bool = True): + normalize_images: bool = True, + optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None): super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output) # Default network architecture, from stable-baselines @@ -65,7 +69,15 @@ class PPOPolicy(BasePolicy): self.net_arch = net_arch self.activation_fn = activation_fn - self.adam_epsilon = adam_epsilon + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in ADAM optimizer + if optimizer == th.optim.Adam: + optimizer_kwargs['eps'] = 1e-5 + + self.optimizer_class = optimizer + self.optimizer_kwargs = optimizer_kwargs self.ortho_init = ortho_init # In the future, feature_extractor will be replaced with a CNN self.features_extractor = nn.Flatten() @@ -142,7 +154,7 @@ class PPOPolicy(BasePolicy): }[module] module.apply(partial(self.init_weights, gain=gain)) # Setup optimizer with initial learning rate - self.optimizer = th.optim.Adam(self.parameters(), lr=lr_schedule(1), eps=self.adam_epsilon) + self.optimizer = 
self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index e83292f..8f4ab51 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple, Callable, Union, Type, Dict +from typing import Optional, List, Tuple, Callable, Union, Type, Dict, Any import gym import torch as th @@ -214,6 +214,10 @@ class SACPolicy(BasePolicy): :param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability. :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) + :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer """ def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, @@ -226,12 +230,20 @@ class SACPolicy(BasePolicy): sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, clip_mean: float = 2.0, - normalize_images: bool = True): + normalize_images: bool = True, + optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None): super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True) if net_arch is None: net_arch = [256, 256] + if optimizer_kwargs is None: + optimizer_kwargs = {} + + self.optimizer_class = optimizer + self.optimizer_kwargs = optimizer_kwargs + # In the future, features_extractor will be replaced with a CNN self.features_extractor = nn.Flatten() self.features_dim = get_obs_dim(self.observation_space) @@ -264,12 +276,14 @@ class SACPolicy(BasePolicy): def _build(self, lr_schedule: Callable) -> 
None: self.actor = self.make_actor() - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1)) + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), + **self.optimizer_kwargs) self.critic = self.make_critic() self.critic_target = self.make_critic() self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1)) + self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), + **self.optimizer_kwargs) def make_actor(self) -> Actor: return Actor(**self.actor_kwargs).to(self.device) diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 994650c..7cc1dca 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Tuple, Callable, Union, Type +from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict import gym import torch as th @@ -259,6 +259,10 @@ class TD3Policy(BasePolicy): above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) + :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer """ def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, @@ -272,13 +276,21 @@ class TD3Policy(BasePolicy): lr_sde: float = 3e-4, sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, - normalize_images: bool = True): + normalize_images: bool = True, + optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None): super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True) # Default network architecture, from the original paper if net_arch is None: net_arch = [400, 300] + if optimizer_kwargs is None: + optimizer_kwargs = {} + + self.optimizer_class = optimizer + self.optimizer_kwargs = optimizer_kwargs + # In the future, features_extractor will be replaced with a CNN self.features_extractor = nn.Flatten() self.features_dim = get_obs_dim(self.observation_space) @@ -318,13 +330,13 @@ class TD3Policy(BasePolicy): self.actor = self.make_actor() self.actor_target = self.make_actor() self.actor_target.load_state_dict(self.actor.state_dict()) - self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1)) - + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), + **self.optimizer_kwargs) self.critic = self.make_critic() self.critic_target = self.make_critic() self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1)) - + self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), + **self.optimizer_kwargs) if self.use_sde: self.vf_net = 
ValueFunction(self.observation_space, self.action_space, features_extractor=self.features_extractor,