Added `optimizer and optimizer_kwargs to policy_kwargs`

This commit is contained in:
Antonin RAFFIN 2020-04-17 15:13:45 +02:00
parent 0e44cdce44
commit aa1026ee87
7 changed files with 107 additions and 50 deletions

View file

@ -3,6 +3,32 @@
Changelog
==========
Pre-Release 0.5.0a0 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
customizer optimizers
Bug Fixes:
^^^^^^^^^^
- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True``
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Cleanup rollout return
Documentation:
^^^^^^^^^^^^^^
Pre-Release 0.4.0 (2020-02-14)
------------------------------
@ -32,28 +58,6 @@ Others:
- Refactored action distributions
Documentation:
^^^^^^^^^^^^^^
Pre-Release 0.5.0a0 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True``
Others:
^^^^^^^
- Cleanup rollout return
Pre-Release 0.3.0 (2020-02-14)
------------------------------

View file

@ -1,6 +1,7 @@
import pytest
import torch as th
from torchy_baselines import PPO
from torchy_baselines import A2C, PPO, SAC, TD3
@pytest.mark.parametrize('net_arch', [
@ -11,5 +12,22 @@ from torchy_baselines import PPO
[12, dict(vf=[8], pi=[8, 4])],
[12, dict(pi=[8])],
])
def test_flexible_mlp(net_arch):
_ = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
@pytest.mark.parametrize('model_class', [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
_ = model_class('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
@pytest.mark.parametrize('net_arch', [
[4],
[4, 4],
])
@pytest.mark.parametrize('model_class', [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
_ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=net_arch)).learn(1000)
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
_ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)

View file

@ -81,19 +81,15 @@ class A2C(PPO):
seed=seed, _init_setup_model=False)
self.normalize_advantage = normalize_advantage
self.rms_prop_eps = rms_prop_eps
self.use_rms_prop = use_rms_prop
# Override PPO optimizer to match original implementation
if use_rms_prop and 'optimizer' not in self.policy_kwargs:
self.policy_kwargs['optimizer'] = th.optim.RMSprop
self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
weight_decay=0)
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super(A2C, self)._setup_model()
if self.use_rms_prop:
self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(),
lr=self.lr_schedule(1), alpha=0.99,
eps=self.rms_prop_eps, weight_decay=0)
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)

View file

@ -37,6 +37,7 @@ class BasePolicy(nn.Module):
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self._squash_output = squash_output
self.optimizer = None # type: Optional[th.optim.Optimizer]
def extract_features(self, obs: th.Tensor) -> th.Tensor:
"""

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Tuple, Callable, Union, Dict, Type
from typing import Optional, List, Tuple, Callable, Union, Dict, Type, Any
from functools import partial
import gym
@ -24,7 +24,6 @@ class PPOPolicy(BasePolicy):
:param net_arch: ([int or dict]) The specification of the policy and value networks.
:param device: (str or th.device) Device on which the code should run.
:param activation_fn: (Type[nn.Module]) Activation function
:param adam_epsilon: (float) Small values to avoid NaN in ADAM optimizer
:param ortho_init: (bool) Whether to use or not orthogonal initialization
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
@ -40,6 +39,10 @@ class PPOPolicy(BasePolicy):
this allows to ensure boundaries when using SDE.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self,
observation_space: gym.spaces.Space,
@ -48,7 +51,6 @@ class PPOPolicy(BasePolicy):
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = 'cpu',
activation_fn: Type[nn.Module] = nn.Tanh,
adam_epsilon: float = 1e-5,
ortho_init: bool = True,
use_sde: bool = False,
log_std_init: float = 0.0,
@ -56,7 +58,9 @@ class PPOPolicy(BasePolicy):
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
squash_output: bool = False,
normalize_images: bool = True):
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
# Default network architecture, from stable-baselines
@ -65,7 +69,15 @@ class PPOPolicy(BasePolicy):
self.net_arch = net_arch
self.activation_fn = activation_fn
self.adam_epsilon = adam_epsilon
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in ADAM optimizer
if optimizer == th.optim.Adam:
optimizer_kwargs['eps'] = 1e-5
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.ortho_init = ortho_init
# In the future, feature_extractor will be replaced with a CNN
self.features_extractor = nn.Flatten()
@ -142,7 +154,7 @@ class PPOPolicy(BasePolicy):
}[module]
module.apply(partial(self.init_weights, gain=gain))
# Setup optimizer with initial learning rate
self.optimizer = th.optim.Adam(self.parameters(), lr=lr_schedule(1), eps=self.adam_epsilon)
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def forward(self, obs: th.Tensor,
deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Tuple, Callable, Union, Type, Dict
from typing import Optional, List, Tuple, Callable, Union, Type, Dict, Any
import gym
import torch as th
@ -214,6 +214,10 @@ class SACPolicy(BasePolicy):
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@ -226,12 +230,20 @@ class SACPolicy(BasePolicy):
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
clip_mean: float = 2.0,
normalize_images: bool = True):
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
if net_arch is None:
net_arch = [256, 256]
if optimizer_kwargs is None:
optimizer_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
# In the future, features_extractor will be replaced with a CNN
self.features_extractor = nn.Flatten()
self.features_dim = get_obs_dim(self.observation_space)
@ -264,12 +276,14 @@ class SACPolicy(BasePolicy):
def _build(self, lr_schedule: Callable) -> None:
self.actor = self.make_actor()
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
**self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1),
**self.optimizer_kwargs)
def make_actor(self) -> Actor:
return Actor(**self.actor_kwargs).to(self.device)

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Tuple, Callable, Union, Type
from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict
import gym
import torch as th
@ -259,6 +259,10 @@ class TD3Policy(BasePolicy):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@ -272,13 +276,21 @@ class TD3Policy(BasePolicy):
lr_sde: float = 3e-4,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
normalize_images: bool = True):
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
# Default network architecture, from the original paper
if net_arch is None:
net_arch = [400, 300]
if optimizer_kwargs is None:
optimizer_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
# In the future, features_extractor will be replaced with a CNN
self.features_extractor = nn.Flatten()
self.features_dim = get_obs_dim(self.observation_space)
@ -318,13 +330,13 @@ class TD3Policy(BasePolicy):
self.actor = self.make_actor()
self.actor_target = self.make_actor()
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
**self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1),
**self.optimizer_kwargs)
if self.use_sde:
self.vf_net = ValueFunction(self.observation_space, self.action_space,
features_extractor=self.features_extractor,