mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Added `optimizer` and `optimizer_kwargs` to `policy_kwargs`
This commit is contained in:
parent
0e44cdce44
commit
aa1026ee87
7 changed files with 107 additions and 50 deletions
|
|
@ -3,6 +3,32 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.5.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily
|
||||
customize optimizers
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Cleanup rollout return
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
||||
Pre-Release 0.4.0 (2020-02-14)
|
||||
------------------------------
|
||||
|
||||
|
|
@ -32,28 +58,6 @@ Others:
|
|||
- Refactored action distributions
|
||||
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
Pre-Release 0.5.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``reset_num_timesteps`` behavior, so ``env.reset()`` is not called if ``reset_num_timesteps=True``
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
- Cleanup rollout return
|
||||
|
||||
|
||||
|
||||
Pre-Release 0.3.0 (2020-02-14)
|
||||
------------------------------
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
import torch as th
|
||||
|
||||
from torchy_baselines import PPO
|
||||
from torchy_baselines import A2C, PPO, SAC, TD3
|
||||
|
||||
|
||||
@pytest.mark.parametrize('net_arch', [
|
||||
|
|
@ -11,5 +12,22 @@ from torchy_baselines import PPO
|
|||
[12, dict(vf=[8], pi=[8, 4])],
|
||||
[12, dict(pi=[8])],
|
||||
])
|
||||
def test_flexible_mlp(net_arch):
|
||||
_ = PPO('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
|
||||
@pytest.mark.parametrize('model_class', [A2C, PPO])
|
||||
def test_flexible_mlp(model_class, net_arch):
|
||||
_ = model_class('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('net_arch', [
|
||||
[4],
|
||||
[4, 4],
|
||||
])
|
||||
@pytest.mark.parametrize('model_class', [SAC, TD3])
|
||||
def test_custom_offpolicy(model_class, net_arch):
|
||||
_ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=net_arch)).learn(1000)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
|
||||
@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)])
|
||||
def test_custom_optimizer(model_class, optimizer_kwargs):
|
||||
policy_kwargs = dict(optimizer=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
|
||||
_ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)
|
||||
|
|
|
|||
|
|
@ -81,19 +81,15 @@ class A2C(PPO):
|
|||
seed=seed, _init_setup_model=False)
|
||||
|
||||
self.normalize_advantage = normalize_advantage
|
||||
self.rms_prop_eps = rms_prop_eps
|
||||
self.use_rms_prop = use_rms_prop
|
||||
# Override PPO optimizer to match original implementation
|
||||
if use_rms_prop and 'optimizer' not in self.policy_kwargs:
|
||||
self.policy_kwargs['optimizer'] = th.optim.RMSprop
|
||||
self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
|
||||
weight_decay=0)
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
super(A2C, self)._setup_model()
|
||||
if self.use_rms_prop:
|
||||
self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(),
|
||||
lr=self.lr_schedule(1), alpha=0.99,
|
||||
eps=self.rms_prop_eps, weight_decay=0)
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class BasePolicy(nn.Module):
|
|||
self.features_extractor = features_extractor
|
||||
self.normalize_images = normalize_images
|
||||
self._squash_output = squash_output
|
||||
self.optimizer = None # type: Optional[th.optim.Optimizer]
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Dict, Type
|
||||
from typing import Optional, List, Tuple, Callable, Union, Dict, Type, Any
|
||||
from functools import partial
|
||||
|
||||
import gym
|
||||
|
|
@ -24,7 +24,6 @@ class PPOPolicy(BasePolicy):
|
|||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (Type[nn.Module]) Activation function
|
||||
:param adam_epsilon: (float) Small values to avoid NaN in ADAM optimizer
|
||||
:param ortho_init: (bool) Whether to use or not orthogonal initialization
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
|
|
@ -40,6 +39,10 @@ class PPOPolicy(BasePolicy):
|
|||
this makes it possible to ensure boundaries when using SDE.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
|
|
@ -48,7 +51,6 @@ class PPOPolicy(BasePolicy):
|
|||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
adam_epsilon: float = 1e-5,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
|
|
@ -56,7 +58,9 @@ class PPOPolicy(BasePolicy):
|
|||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False,
|
||||
normalize_images: bool = True):
|
||||
normalize_images: bool = True,
|
||||
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
|
|
@ -65,7 +69,15 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.adam_epsilon = adam_epsilon
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
# Small values to avoid NaN in ADAM optimizer
|
||||
if optimizer == th.optim.Adam:
|
||||
optimizer_kwargs['eps'] = 1e-5
|
||||
|
||||
self.optimizer_class = optimizer
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
self.ortho_init = ortho_init
|
||||
# In the future, feature_extractor will be replaced with a CNN
|
||||
self.features_extractor = nn.Flatten()
|
||||
|
|
@ -142,7 +154,7 @@ class PPOPolicy(BasePolicy):
|
|||
}[module]
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
# Setup optimizer with initial learning rate
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=lr_schedule(1), eps=self.adam_epsilon)
|
||||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def forward(self, obs: th.Tensor,
|
||||
deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Type, Dict
|
||||
from typing import Optional, List, Tuple, Callable, Union, Type, Dict, Any
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -214,6 +214,10 @@ class SACPolicy(BasePolicy):
|
|||
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
|
|
@ -226,12 +230,20 @@ class SACPolicy(BasePolicy):
|
|||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
clip_mean: float = 2.0,
|
||||
normalize_images: bool = True):
|
||||
normalize_images: bool = True,
|
||||
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
if net_arch is None:
|
||||
net_arch = [256, 256]
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
||||
self.optimizer_class = optimizer
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
|
||||
# In the future, features_extractor will be replaced with a CNN
|
||||
self.features_extractor = nn.Flatten()
|
||||
self.features_dim = get_obs_dim(self.observation_space)
|
||||
|
|
@ -264,12 +276,14 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs)
|
||||
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
|
||||
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs)
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Type
|
||||
from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -259,6 +259,10 @@ class TD3Policy(BasePolicy):
|
|||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
|
|
@ -272,13 +276,21 @@ class TD3Policy(BasePolicy):
|
|||
lr_sde: float = 3e-4,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
normalize_images: bool = True):
|
||||
normalize_images: bool = True,
|
||||
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
if net_arch is None:
|
||||
net_arch = [400, 300]
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
||||
self.optimizer_class = optimizer
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
|
||||
# In the future, features_extractor will be replaced with a CNN
|
||||
self.features_extractor = nn.Flatten()
|
||||
self.features_dim = get_obs_dim(self.observation_space)
|
||||
|
|
@ -318,13 +330,13 @@ class TD3Policy(BasePolicy):
|
|||
self.actor = self.make_actor()
|
||||
self.actor_target = self.make_actor()
|
||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
|
||||
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs)
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
|
||||
|
||||
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs)
|
||||
if self.use_sde:
|
||||
self.vf_net = ValueFunction(self.observation_space, self.action_space,
|
||||
features_extractor=self.features_extractor,
|
||||
|
|
|
|||
Loading…
Reference in a new issue