Add CNN support for TD3

This commit is contained in:
Antonin RAFFIN 2020-04-22 11:05:46 +02:00
parent 8f4155180e
commit 73fb8d1c63
5 changed files with 135 additions and 27 deletions

View file

@ -1,10 +1,16 @@
import os
import numpy as np
import pytest
from torchy_baselines import A2C, PPO, SAC, TD3
from torchy_baselines.common.identity_env import FakeImageEnv
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC])
SAVE_PATH = './cnn_model.zip'
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
def test_cnn(model_class):
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
@ -16,5 +22,19 @@ def test_cnn(model_class):
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=500, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=40)))
_ = model_class('CnnPolicy', env, **kwargs).learn(500)
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512)))
model = model_class('CnnPolicy', env, **kwargs).learn(250)
obs = env.reset()
action, _ = model.predict(obs, deterministic=True)
model.save(SAVE_PATH)
del model
model = model_class.load(SAVE_PATH)
# Check that the prediction is the same
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
os.remove(SAVE_PATH)

View file

@ -7,7 +7,7 @@ import torch as th
import torch.nn as nn
import numpy as np
from torchy_baselines.common.preprocessing import preprocess_obs, get_obs_dim, is_image_space
from torchy_baselines.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space
from torchy_baselines.common.utils import get_device, get_schedule_fn
from torchy_baselines.common.vec_env import VecTransposeImage
@ -482,7 +482,7 @@ class BaseFeaturesExtractor(nn.Module):
class FlattenExtractor(BaseFeaturesExtractor):
def __init__(self, observation_space: gym.Space):
super(FlattenExtractor, self).__init__(observation_space, get_obs_dim(observation_space))
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
self.flatten = nn.Flatten()
def forward(self, observations: th.Tensor) -> th.Tensor:

View file

@ -79,10 +79,10 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
raise NotImplementedError()
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
"""
Get the dimension of the observation space.
It should not be used when using images.
Get the dimension of the observation space once it is flattened.
It is not meant to be used with image observation spaces.
:param observation_space: (spaces.Space)
:return: (Union[int, Tuple[int, ...]])

View file

@ -61,10 +61,7 @@ class Actor(BasePolicy):
device=device,
squash_output=True)
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
# Save arguments to re-create object at loading
self.use_sde = use_sde
self.sde_features_extractor = None
self.sde_net_arch = sde_net_arch
@ -76,6 +73,10 @@ class Actor(BasePolicy):
self.use_expln = use_expln
self.full_std = full_std
self.clip_mean = clip_mean
action_dim = get_action_dim(self.action_space)
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
self.latent_pi = nn.Sequential(*latent_pi_net)
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
@ -218,7 +219,10 @@ class Critic(BasePolicy):
self.q_networks = [self.q1_net, self.q2_net]
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
features = self.extract_features(obs)
# Do not backpropagate through the features extractor here:
# it is learned using the policy loss only, which is much faster
with th.no_grad():
features = self.extract_features(obs)
qvalue_input = th.cat([features, action], dim=1)
return [q_net(qvalue_input) for q_net in self.q_networks]
@ -279,13 +283,14 @@ class SACPolicy(BasePolicy):
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
@ -325,8 +330,7 @@ class SACPolicy(BasePolicy):
self.critic_target.load_state_dict(self.critic.state_dict())
# Do not optimize the shared feature extractor with the critic loss
# otherwise, there are gradient computation issues
# another solution: having duplicated features extractor but requires more memory and computation
# Note: check gradients, they are maybe computed but not zeroed by the critic
# Another solution: duplicate the features extractor, but this requires more memory and computation
critic_parameters = [param for name, param in self.critic.named_parameters() if 'features_extractor' not in name]
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1),
**self.optimizer_kwargs)

View file

@ -4,9 +4,10 @@ import gym
import torch as th
import torch.nn as nn
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
from torchy_baselines.common.preprocessing import get_action_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_features_extractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from torchy_baselines.common.distributions import StateDependentNoiseDistribution, Distribution
@ -221,12 +222,15 @@ class Critic(BasePolicy):
self.q2_net = nn.Sequential(*q2_net)
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
features = self.extract_features(obs)
# Learn the features extractor using the policy loss only
with th.no_grad():
features = self.extract_features(obs)
qvalue_input = th.cat([features, actions], dim=1)
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
features = self.extract_features(obs)
with th.no_grad():
features = self.extract_features(obs)
return self.q1_net(th.cat([features, actions], dim=1))
@ -262,7 +266,8 @@ class ValueFunction(BasePolicy):
self.vf_net = nn.Sequential(*vf_net)
def forward(self, obs: th.Tensor) -> th.Tensor:
features = self.extract_features(obs)
with th.no_grad():
features = self.extract_features(obs)
return self.vf_net(features)
@ -284,6 +289,9 @@ class TD3Policy(BasePolicy):
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
@ -303,6 +311,8 @@ class TD3Policy(BasePolicy):
lr_sde: float = 3e-4,
sde_net_arch: Optional[List[int]] = None,
use_expln: bool = False,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None):
@ -310,17 +320,24 @@ class TD3Policy(BasePolicy):
# Default network architecture, from the original paper
if net_arch is None:
net_arch = [400, 300]
if features_extractor_class == FlattenExtractor:
net_arch = [400, 300]
else:
net_arch = []
if optimizer_kwargs is None:
optimizer_kwargs = {}
if features_extractor_kwargs is None:
features_extractor_kwargs = {}
self.optimizer_class = optimizer
self.optimizer_kwargs = optimizer_kwargs
# In the future, features_extractor will be replaced with a CNN
self.features_extractor = nn.Flatten()
self.features_dim = get_obs_dim(self.observation_space)
self.features_extractor_class = features_extractor_class
self.features_extractor_kwargs = features_extractor_kwargs
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
@ -384,7 +401,9 @@ class TD3Policy(BasePolicy):
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data
@ -406,4 +425,69 @@ class TD3Policy(BasePolicy):
MlpPolicy = TD3Policy
class CnnPolicy(TD3Policy):
    """
    TD3 policy (actor and critic) that uses ``NatureCNN`` as its default features extractor.

    The constructor forwards every argument, unchanged and in the same order,
    to ``TD3Policy``; only the default of ``features_extractor_class`` differs
    (``TD3Policy`` defaults to ``FlattenExtractor``).

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (Callable) Learning rate schedule (could be constant)
    :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
    :param device: (Union[th.device, str]) Device on which the code should run.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param clip_noise: (Optional[float]) Passed as-is to ``TD3Policy``.
        NOTE(review): its meaning is not documented here, check the parent class.
    :param lr_sde: (float) Passed as-is to ``TD3Policy``.
        NOTE(review): presumably the learning rate used for SDE, to be confirmed.
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using SDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
        a positive standard deviation (cf paper). It keeps the variance
        above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use
        (``NatureCNN`` by default).
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable,
        net_arch: Optional[List[int]] = None,
        device: Union[th.device, str] = 'auto',
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        clip_noise: Optional[float] = None,
        lr_sde: float = 3e-4,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # Positional pass-through: the order below mirrors this signature exactly.
        super().__init__(
            observation_space, action_space, lr_schedule, net_arch, device,
            activation_fn, use_sde, log_std_init, clip_noise, lr_sde,
            sde_net_arch, use_expln,
            features_extractor_class, features_extractor_kwargs,
            normalize_images, optimizer, optimizer_kwargs,
        )
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)