From 73fb8d1c6303fa7e7bf53522e59ac91e1ea36fd7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 11:05:46 +0200 Subject: [PATCH] Add CNN support for TD3 --- tests/test_cnn.py | 26 +++++- torchy_baselines/common/policies.py | 4 +- torchy_baselines/common/preprocessing.py | 6 +- torchy_baselines/sac/policies.py | 22 +++-- torchy_baselines/td3/policies.py | 104 ++++++++++++++++++++--- 5 files changed, 135 insertions(+), 27 deletions(-) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index d0b7ed7..9aceafb 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -1,10 +1,16 @@ +import os + +import numpy as np import pytest from torchy_baselines import A2C, PPO, SAC, TD3 from torchy_baselines.common.identity_env import FakeImageEnv -@pytest.mark.parametrize('model_class', [A2C, PPO, SAC]) +SAVE_PATH = './cnn_model.zip' + + +@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3]) def test_cnn(model_class): # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution @@ -16,5 +22,19 @@ def test_cnn(model_class): else: # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=500, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=40))) - _ = model_class('CnnPolicy', env, **kwargs).learn(500) + kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512))) + model = model_class('CnnPolicy', env, **kwargs).learn(250) + + obs = env.reset() + + action, _ = model.predict(obs, deterministic=True) + + model.save(SAVE_PATH) + del model + + model = model_class.load(SAVE_PATH) + + # Check that the prediction is the same + assert np.allclose(action, model.predict(obs, deterministic=True)[0]) + + os.remove(SAVE_PATH) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 675e9be..8250616 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -7,7 +7,7 @@ import torch as th import torch.nn as nn import numpy as np -from torchy_baselines.common.preprocessing import preprocess_obs, get_obs_dim, is_image_space +from torchy_baselines.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space from torchy_baselines.common.utils import get_device, get_schedule_fn from torchy_baselines.common.vec_env import VecTransposeImage @@ -482,7 +482,7 @@ class BaseFeaturesExtractor(nn.Module): class FlattenExtractor(BaseFeaturesExtractor): def __init__(self, observation_space: gym.Space): - super(FlattenExtractor, self).__init__(observation_space, get_obs_dim(observation_space)) + super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space)) self.flatten = nn.Flatten() def forward(self, observations: th.Tensor) -> th.Tensor: diff --git a/torchy_baselines/common/preprocessing.py b/torchy_baselines/common/preprocessing.py index a52226e..2caebcc 100644 --- a/torchy_baselines/common/preprocessing.py +++ b/torchy_baselines/common/preprocessing.py @@ -79,10 +79,10 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: raise NotImplementedError() -def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: +def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: """ - Get the dimension of the observation space. - It should not be used when using images. + Get the dimension of the observation space when flattened. + It does not apply to image observation space. :param observation_space: (spaces.Space) :return: (Union[int, Tuple[int, ...]]) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index f066da0..97ba375 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -61,10 +61,7 @@ class Actor(BasePolicy): device=device, squash_output=True) - action_dim = get_action_dim(self.action_space) - - latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn) - self.latent_pi = nn.Sequential(*latent_pi_net) + # Save arguments to re-create object at loading self.use_sde = use_sde self.sde_features_extractor = None self.sde_net_arch = sde_net_arch @@ -76,6 +73,10 @@ class Actor(BasePolicy): self.use_expln = use_expln self.full_std = full_std self.clip_mean = clip_mean + + action_dim = get_action_dim(self.action_space) + latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn) + self.latent_pi = nn.Sequential(*latent_pi_net) last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim @@ -218,7 +219,10 @@ class Critic(BasePolicy): self.q_networks = [self.q1_net, self.q2_net] def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: - features = self.extract_features(obs) + # Learn the features extractor using the policy loss only + # this is much faster + with th.no_grad(): + features = self.extract_features(obs) qvalue_input = th.cat([features, action], dim=1) return [q_net(qvalue_input) for q_net in self.q_networks] @@ -279,13 +283,14 @@ class SACPolicy(BasePolicy): if optimizer_kwargs is None: optimizer_kwargs = {} + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + self.optimizer_class = optimizer self.optimizer_kwargs = optimizer_kwargs self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs - if features_extractor_kwargs is None: - features_extractor_kwargs = {} self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim @@ -325,8 +330,7 @@ class SACPolicy(BasePolicy): self.critic_target.load_state_dict(self.critic.state_dict()) # Do not optimize the shared feature extractor with the critic loss # otherwise, there are gradient computation issues - # another solution: having duplicated features extractor but requires more memory and computation - # Note: check gradients, they are maybe computed but not zeroed by the critic + # Another solution: having duplicated features extractor but requires more memory and computation critic_parameters = [param for name, param in self.critic.named_parameters() if 'features_extractor' not in name] self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index c1ecbc8..0c18526 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -4,9 +4,10 @@ import gym import torch as th import torch.nn as nn -from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim +from torchy_baselines.common.preprocessing import get_action_dim from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, - create_sde_features_extractor) + create_sde_features_extractor, NatureCNN, + BaseFeaturesExtractor, FlattenExtractor) from torchy_baselines.common.distributions import StateDependentNoiseDistribution, Distribution @@ -221,12 +222,15 @@ class Critic(BasePolicy): self.q2_net = nn.Sequential(*q2_net) def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - features = self.extract_features(obs) + # Learn the features extractor using the policy loss only + with th.no_grad(): + features = self.extract_features(obs) qvalue_input = th.cat([features, actions], dim=1) return self.q1_net(qvalue_input), self.q2_net(qvalue_input) def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - features = self.extract_features(obs) + with th.no_grad(): + features = self.extract_features(obs) return self.q1_net(th.cat([features, actions], dim=1)) @@ -262,7 +266,8 @@ class ValueFunction(BasePolicy): self.vf_net = nn.Sequential(*vf_net) def forward(self, obs: th.Tensor) -> th.Tensor: - features = self.extract_features(obs) + with th.no_grad(): + features = self.extract_features(obs) return self.vf_net(features) @@ -284,6 +289,9 @@ class TD3Policy(BasePolicy): :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor. :param normalize_images: (bool) Whether to normalize images or not, dividing by 255.0 (True by default) :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, @@ -303,6 +311,8 @@ class TD3Policy(BasePolicy): lr_sde: float = 3e-4, sde_net_arch: Optional[List[int]] = None, use_expln: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer: Type[th.optim.Optimizer] = th.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None): @@ -310,17 +320,24 @@ class TD3Policy(BasePolicy): # Default network architecture, from the original paper if net_arch is None: - net_arch = [400, 300] + if features_extractor_class == FlattenExtractor: + net_arch = [400, 300] + else: + net_arch = [] if optimizer_kwargs is None: optimizer_kwargs = {} + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + self.optimizer_class = optimizer self.optimizer_kwargs = optimizer_kwargs - # In the future, features_extractor will be replaced with a CNN - self.features_extractor = nn.Flatten() - self.features_dim = get_obs_dim(self.observation_space) + self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn @@ -384,7 +401,9 @@ class TD3Policy(BasePolicy): use_expln=self.actor_kwargs['use_expln'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs )) return data @@ -406,4 +425,69 @@ class TD3Policy(BasePolicy): MlpPolicy = TD3Policy + +class CnnPolicy(TD3Policy): + """ + Policy class (with both actor and critic) for TD3. + + :param observation_space: (gym.spaces.Space) Observation space + :param action_space: (gym.spaces.Space) Action space + :param lr_schedule: (Callable) Learning rate schedule (could be constant) + :param net_arch: (Optional[List[int]]) The specification of the policy and value networks. + :param device: (Union[th.device, str]) Device on which the code should run. + :param activation_fn: (Type[nn.Module]) Activation function + :param use_sde: (bool) Whether to use State Dependent Exploration or not + :param log_std_init: (float) Initial value for the log standard deviation + :param sde_net_arch: ([int]) Network architecture for extracting features + when using SDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use. + :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments + to pass to the feature extractor. + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = 'auto', + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + clip_noise: Optional[float] = None, + lr_sde: float = 3e-4, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None): + super(CnnPolicy, self).__init__(observation_space, + action_space, + lr_schedule, + net_arch, + device, + activation_fn, + use_sde, + log_std_init, + clip_noise, + lr_sde, + sde_net_arch, + use_expln, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer, + optimizer_kwargs) + register_policy("MlpPolicy", MlpPolicy) +register_policy("CnnPolicy", CnnPolicy)