From 041f2bc59a67269f7e0397476a6379cd2f5b7ee6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 13:14:22 +0200 Subject: [PATCH] Cleanup, bug fixes + more tests --- docs/misc/changelog.rst | 4 +- tests/test_cnn.py | 3 +- tests/test_save_load.py | 34 ++++-- torchy_baselines/a2c/a2c.py | 4 +- torchy_baselines/common/policies.py | 169 ++++++++++++++++++---------- torchy_baselines/ppo/policies.py | 57 ++++++---- torchy_baselines/sac/policies.py | 35 +++--- torchy_baselines/td3/policies.py | 34 +++--- 8 files changed, 208 insertions(+), 132 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 57a0a33..74b5c5b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,10 +12,10 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ -- Added ``optimizer`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily +- Added ``optimizer_class`` and ``optimizer_kwargs`` to ``policy_kwargs`` in order to easily customizer optimizers - Complete independent save/load for policies -- Add ``CnnPolicies`` to support images as input +- Add ``CnnPolicy`` and ``VecTransposeImage`` to support images as input Bug Fixes: diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 9aceafb..92292ec 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -22,7 +22,8 @@ def test_cnn(model_class): else: # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512))) + kwargs = dict(buffer_size=250, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) model = model_class('CnnPolicy', env, **kwargs).learn(250) obs = env.reset() diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 1e547c0..b1005cc 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -8,6 +8,8 @@ import torch as th from torchy_baselines import A2C, PPO, SAC, TD3 from torchy_baselines.common.identity_env 
import IdentityEnvBox from torchy_baselines.common.vec_env import DummyVecEnv +from torchy_baselines.common.identity_env import FakeImageEnv + MODEL_LIST = [ PPO, @@ -30,12 +32,11 @@ def test_save_load(model_class): env = DummyVecEnv([lambda: IdentityEnvBox(10)]) # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1) model.learn(total_timesteps=500, eval_freq=250) env.reset() - observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) - observations = observations.reshape(10, -1) + observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0) # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) @@ -90,7 +91,7 @@ def test_set_env(model_class): env3 = IdentityEnvBox(10) # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), create_eval_env=True) + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16])) # learn model.learn(total_timesteps=1000, eval_freq=500) @@ -115,7 +116,7 @@ def test_exclude_include_saved_params(model_class): env = DummyVecEnv([lambda: IdentityEnvBox(10)]) # create model, set verbose as 2, which is not standard - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2, create_eval_env=True) + model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=2) # Check if exclude works model.save("test_save.zip", exclude=["verbose"]) @@ -163,21 +164,34 @@ def test_save_load_replay_buffer(model_class): @pytest.mark.parametrize("model_class", MODEL_LIST) -def test_save_load_policy(model_class): +@pytest.mark.parametrize("policy_str", ['MlpPolicy', 'CnnPolicy']) +def test_save_load_policy(model_class, policy_str): """ Test saving and loading policy only. 
:param model_class: (BaseRLModel) A RL model + :param policy_str: (str) Name of the policy. """ - env = DummyVecEnv([lambda: IdentityEnvBox(10)]) + kwargs = {} + if policy_str == 'MlpPolicy': + env = IdentityEnvBox(10) + else: + if model_class in [SAC, TD3]: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict(buffer_size=250) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=3, + discrete=False) + + env = DummyVecEnv([lambda: env]) # create model - model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), + verbose=1, **kwargs) model.learn(total_timesteps=500, eval_freq=250) env.reset() - observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)]) - observations = observations.reshape(10, -1) + observations = np.concatenate([env.step(env.action_space.sample())[0] for _ in range(10)], axis=0) policy = model.policy policy_class = policy.__class__ diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index 78e9038..ff10bf4 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -82,8 +82,8 @@ class A2C(PPO): self.normalize_advantage = normalize_advantage # Override PPO optimizer to match original implementation - if use_rms_prop and 'optimizer' not in self.policy_kwargs: - self.policy_kwargs['optimizer'] = th.optim.RMSprop + if use_rms_prop and 'optimizer_class' not in self.policy_kwargs: + self.policy_kwargs['optimizer_class'] = th.optim.RMSprop self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 8250616..701cc9d 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -12,6 +12,79 @@ from torchy_baselines.common.utils import get_device, get_schedule_fn 
from torchy_baselines.common.vec_env import VecTransposeImage +class BaseFeaturesExtractor(nn.Module): + """ + Base class that represents a features extractor. + + :param observation_space: (gym.Space) + :param features_dim: (int) Number of features extracted. + """ + def __init__(self, observation_space: gym.Space, features_dim: int = 0): + super(BaseFeaturesExtractor, self).__init__() + assert features_dim > 0 + self._observation_space = observation_space + self._features_dim = features_dim + + @property + def features_dim(self) -> int: + return self._features_dim + + def forward(self, observations: th.Tensor) -> th.Tensor: + raise NotImplementedError() + + +class FlattenExtractor(BaseFeaturesExtractor): + """ + Feature extractor that flattens the input. + Used as a placeholder when feature extraction is not needed. + + :param observation_space: (gym.Space) + """ + def __init__(self, observation_space: gym.Space): + super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) + self.flatten = nn.Flatten() + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.flatten(observations) + + +class NatureCNN(BaseFeaturesExtractor): + """ + CNN from DQN nature paper: https://arxiv.org/abs/1312.5602 + + :param observation_space: (gym.Space) + :param features_dim: (int) Number of features extracted. + This corresponds to the number of units for the last layer. + """ + def __init__(self, observation_space: gym.spaces.Box, + features_dim: int = 512): + super(NatureCNN, self).__init__(observation_space, features_dim) + # TODO: custom init?
+ # We assume CxWxH images (channels first) + # Re-ordering will be done by pre-processing or wrapper + is_image_input = is_image_space(observation_space) or is_image_space(observation_space, channels_last=False) + assert is_image_input, ('You should use NatureCNN ' + f'only with images not with {observation_space} ' + '(you are probably using `CnnPolicy` instead of `MlpPolicy`)') + n_input_channels = observation_space.shape[0] + self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), + nn.ReLU(), + nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), + nn.ReLU(), + nn.Flatten()) + + # Compute shape by doing one forward pass + with th.no_grad(): + n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] + + self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) + + def forward(self, observations: th.Tensor) -> th.Tensor: + return self.linear(self.cnn(observations)) + + class BasePolicy(nn.Module): """ The base policy object @@ -21,26 +94,50 @@ class BasePolicy(nn.Module): :param device: (Union[th.device, str]) Device on which the code should run. :param squash_output: (bool) For continuous actions, whether the output is squashed or not using a ``tanh()`` function. + :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor.
:param features_extractor: (nn.Module) Network to extract features (a CNN when using images, a nn.Flatten() layer otherwise) :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer """ def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, device: Union[th.device, str] = 'auto', - squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, features_extractor: Optional[nn.Module] = None, - normalize_images: bool = True): + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + squash_output: bool = False): super(BasePolicy, self).__init__() + + if optimizer_kwargs is None: + optimizer_kwargs = {} + + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + self.observation_space = observation_space self.action_space = action_space self.device = get_device(device) self.features_extractor = features_extractor self.normalize_images = normalize_images self._squash_output = squash_output + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs self.optimizer = None # type: Optional[th.optim.Optimizer] + self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + def extract_features(self, obs: th.Tensor) -> th.Tensor: """ Preprocess the observation if needed and extract features. 
@@ -62,7 +159,7 @@ class BasePolicy(nn.Module): """ Orthogonal initialization (used in PPO and A2C) """ - if isinstance(module, nn.Linear): + if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) @@ -104,11 +201,18 @@ class BasePolicy(nn.Module): # mask = [False for _ in range(self.n_envs)] observation = np.array(observation) + # Handle the different cases for images + # as PyTorch uses channel first format if is_image_space(self.observation_space, channels_last=False): - # TODO: handle the different cases - if (observation.shape != self.observation_space.shape - and observation.shape[1:] != self.observation_space.shape): - observation = VecTransposeImage.transpose_image(observation) + if (observation.shape == self.observation_space.shape or + observation.shape[1:] == self.observation_space.shape): + pass + else: + # Try to re-order the channels + transpose_obs = VecTransposeImage.transpose_image(observation) + if (transpose_obs.shape == self.observation_space.shape + or transpose_obs.shape[1:] == self.observation_space.shape): + observation = transpose_obs vectorized_env = self._is_vectorized_observation(observation, self.observation_space) @@ -463,54 +567,3 @@ class MlpExtractor(nn.Module): """ shared_latent = self.shared_net(features) return self.policy_net(shared_latent), self.value_net(shared_latent) - - -class BaseFeaturesExtractor(nn.Module): - def __init__(self, observation_space: gym.Space, features_dim: int = 0): - super(BaseFeaturesExtractor, self).__init__() - assert features_dim > 0 - self._observation_space = observation_space - self._features_dim = features_dim - - @property - def features_dim(self) -> int: - return self._features_dim - - def forward(self, observations: th.Tensor) -> th.Tensor: - raise NotImplementedError() - - -class FlattenExtractor(BaseFeaturesExtractor): - def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space,
get_flattened_obs_dim(observation_space)) - self.flatten = nn.Flatten() - - def forward(self, observations: th.Tensor) -> th.Tensor: - return self.flatten(observations) - - -class NatureCNN(BaseFeaturesExtractor): - def __init__(self, observation_space: gym.Space, - features_dim: int = 512): - super(NatureCNN, self).__init__(observation_space, features_dim) - # TODO: custom init? - # TODO: check that the observation space is an image - # we assume CxWxH images - # assert is_image_space(observation_space), observation_space - n_input_channels = observation_space.shape[0] - self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), - nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), - nn.ReLU(), - nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), - nn.ReLU(), - nn.Flatten()) - - # Compute shape by doing one forward pass - with th.no_grad(): - n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1] - - self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU()) - - def forward(self, observations: th.Tensor) -> th.Tensor: - return self.linear(self.cnn(observations)) diff --git a/torchy_baselines/ppo/policies.py b/torchy_baselines/ppo/policies.py index 7aa8b88..12fe737 100644 --- a/torchy_baselines/ppo/policies.py +++ b/torchy_baselines/ppo/policies.py @@ -38,9 +38,11 @@ class PPOPolicy(BasePolicy): :param squash_output: (bool) Whether to squash the output using a tanh function, this allows to ensure boundaries when using SDE. :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -60,10 +62,24 @@ class PPOPolicy(BasePolicy): use_expln: bool = False, squash_output: bool = False, features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output) + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in ADAM optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs['eps'] = 1e-5 + + super(PPOPolicy, self).__init__(observation_space, action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output) # Default network architecture, from stable-baselines if net_arch is None: @@ -74,19 +90,10 @@ class PPOPolicy(BasePolicy): self.net_arch = net_arch self.activation_fn = activation_fn - - if optimizer_kwargs is None: - optimizer_kwargs = {} - # Small values to avoid NaN in ADAM optimizer - if optimizer == th.optim.Adam: - optimizer_kwargs['eps'] = 1e-5 - - self.optimizer_class = optimizer - self.optimizer_kwargs = optimizer_kwargs self.ortho_init = ortho_init - self.features_extractor_class = features_extractor_class - self.features_extractor = features_extractor_class(self.observation_space) + self.features_extractor = 
features_extractor_class(self.observation_space, + **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.normalize_images = normalize_images @@ -124,10 +131,11 @@ class PPOPolicy(BasePolicy): sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, ortho_init=self.ortho_init, - features_extractor_class=self.features_extractor_class + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs )) return data @@ -173,9 +181,12 @@ class PPOPolicy(BasePolicy): # Init weights: use orthogonal initialization # with small initial weight for the output if self.ortho_init: - for module in [self.mlp_extractor, self.action_net, self.value_net]: + # TODO: check for features_extractor + for module in [self.features_extractor, self.mlp_extractor, + self.action_net, self.value_net]: # Values from stable-baselines, TODO: check why gain = { + self.features_extractor: np.sqrt(2), self.mlp_extractor: np.sqrt(2), self.action_net: 0.01, self.value_net: 1 @@ -300,9 +311,11 @@ class CnnPolicy(PPOPolicy): :param squash_output: (bool) Whether to squash the output using a tanh function, this allows to ensure boundaries when using SDE. :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -322,8 +335,9 @@ class CnnPolicy(PPOPolicy): use_expln: bool = False, squash_output: bool = False, features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): super(CnnPolicy, self).__init__(observation_space, action_space, @@ -339,8 +353,9 @@ class CnnPolicy(PPOPolicy): use_expln, squash_output, features_extractor_class, + features_extractor_kwargs, normalize_images, - optimizer, + optimizer_class, optimizer_kwargs) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 97ba375..4205e81 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -251,7 +251,7 @@ class SACPolicy(BasePolicy): to pass to the feature extractor. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -270,9 +270,15 @@ class SACPolicy(BasePolicy): features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True) + super(SACPolicy, self).__init__(observation_space, action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True) if net_arch is None: if features_extractor_class == FlattenExtractor: @@ -280,18 +286,9 @@ class SACPolicy(BasePolicy): else: net_arch = [] - if optimizer_kwargs is None: - optimizer_kwargs = {} - - if features_extractor_kwargs is None: - features_extractor_kwargs = {} - - self.optimizer_class = optimizer - self.optimizer_kwargs = optimizer_kwargs - - self.features_extractor_class = features_extractor_class - self.features_extractor_kwargs = features_extractor_kwargs - self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs) + # Create shared features extractor + self.features_extractor = features_extractor_class(self.observation_space, + **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch @@ -347,7 +344,7 @@ class SACPolicy(BasePolicy): 
use_expln=self.actor_kwargs['use_expln'], clip_mean=self.actor_kwargs['clip_mean'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer=self.optimizer_class, + optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, features_extractor_class=self.features_extractor_class, features_extractor_kwargs=self.features_extractor_kwargs @@ -392,7 +389,7 @@ class CnnPolicy(SACPolicy): :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -411,7 +408,7 @@ class CnnPolicy(SACPolicy): features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): super(CnnPolicy, self).__init__(observation_space, action_space, @@ -427,7 +424,7 @@ class CnnPolicy(SACPolicy): features_extractor_class, features_extractor_kwargs, normalize_images, - optimizer, + optimizer_class, optimizer_kwargs) diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 0c18526..23aa670 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -294,7 +294,7 @@ class TD3Policy(BasePolicy): to pass to the feature extractor. 
:param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -314,9 +314,15 @@ class TD3Policy(BasePolicy): features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True) + super(TD3Policy, self).__init__(observation_space, action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True) # Default network architecture, from the original paper if net_arch is None: @@ -325,18 +331,8 @@ class TD3Policy(BasePolicy): else: net_arch = [] - if optimizer_kwargs is None: - optimizer_kwargs = {} - - if features_extractor_kwargs is None: - features_extractor_kwargs = {} - - self.optimizer_class = optimizer - self.optimizer_kwargs = optimizer_kwargs - - self.features_extractor_class = features_extractor_class - self.features_extractor_kwargs = features_extractor_kwargs - self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs) + self.features_extractor = features_extractor_class(self.observation_space, + **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch @@ -400,7 +396,7 @@ class TD3Policy(BasePolicy): sde_net_arch=self.actor_kwargs['sde_net_arch'], 
use_expln=self.actor_kwargs['use_expln'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer=self.optimizer_class, + optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, features_extractor_class=self.features_extractor_class, features_extractor_kwargs=self.features_extractor_kwargs @@ -449,7 +445,7 @@ class CnnPolicy(TD3Policy): to pass to the feature extractor. :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) - :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use, ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer @@ -469,7 +465,7 @@ class CnnPolicy(TD3Policy): features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, - optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): super(CnnPolicy, self).__init__(observation_space, action_space, @@ -486,7 +482,7 @@ class CnnPolicy(TD3Policy): features_extractor_class, features_extractor_kwargs, normalize_images, - optimizer, + optimizer_class, optimizer_kwargs) register_policy("MlpPolicy", MlpPolicy)