mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
Add CNN support for TD3
This commit is contained in:
parent
8f4155180e
commit
73fb8d1c63
5 changed files with 135 additions and 27 deletions
|
|
@ -1,10 +1,16 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from torchy_baselines import A2C, PPO, SAC, TD3
|
||||
from torchy_baselines.common.identity_env import FakeImageEnv
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC])
|
||||
SAVE_PATH = './cnn_model.zip'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
|
||||
def test_cnn(model_class):
|
||||
# Fake grayscale with frameskip
|
||||
# Atari after preprocessing: 84x84x1, here we are using lower resolution
|
||||
|
|
@ -16,5 +22,19 @@ def test_cnn(model_class):
|
|||
else:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=500, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=40)))
|
||||
_ = model_class('CnnPolicy', env, **kwargs).learn(500)
|
||||
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=512)))
|
||||
model = model_class('CnnPolicy', env, **kwargs).learn(250)
|
||||
|
||||
obs = env.reset()
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
|
||||
model.save(SAVE_PATH)
|
||||
del model
|
||||
|
||||
model = model_class.load(SAVE_PATH)
|
||||
|
||||
# Check that the prediction is the same
|
||||
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
|
||||
|
||||
os.remove(SAVE_PATH)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.preprocessing import preprocess_obs, get_obs_dim, is_image_space
|
||||
from torchy_baselines.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space
|
||||
from torchy_baselines.common.utils import get_device, get_schedule_fn
|
||||
from torchy_baselines.common.vec_env import VecTransposeImage
|
||||
|
||||
|
|
@ -482,7 +482,7 @@ class BaseFeaturesExtractor(nn.Module):
|
|||
|
||||
class FlattenExtractor(BaseFeaturesExtractor):
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
super(FlattenExtractor, self).__init__(observation_space, get_obs_dim(observation_space))
|
||||
super(FlattenExtractor, self).__init__(observation_space, get_flattened_obs_dim(observation_space))
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def forward(self, observations: th.Tensor) -> th.Tensor:
|
||||
|
|
|
|||
|
|
@ -79,10 +79,10 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
|
||||
def get_flattened_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
|
||||
"""
|
||||
Get the dimension of the observation space.
|
||||
It should not be used when using images.
|
||||
Get the dimension of the observation space when flattened.
|
||||
It does not apply to image observation spaces.
|
||||
|
||||
:param observation_space: (spaces.Space)
|
||||
:return: (Union[int, Tuple[int, ...]])
|
||||
|
|
|
|||
|
|
@ -61,10 +61,7 @@ class Actor(BasePolicy):
|
|||
device=device,
|
||||
squash_output=True)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
# Save arguments to re-create object at loading
|
||||
self.use_sde = use_sde
|
||||
self.sde_features_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
|
|
@ -76,6 +73,10 @@ class Actor(BasePolicy):
|
|||
self.use_expln = use_expln
|
||||
self.full_std = full_std
|
||||
self.clip_mean = clip_mean
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim
|
||||
|
||||
|
||||
|
|
@ -218,7 +219,10 @@ class Critic(BasePolicy):
|
|||
self.q_networks = [self.q1_net, self.q2_net]
|
||||
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
||||
features = self.extract_features(obs)
|
||||
# Learn the features extractor using the policy loss only
|
||||
# this is much faster
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, action], dim=1)
|
||||
return [q_net(qvalue_input) for q_net in self.q_networks]
|
||||
|
||||
|
|
@ -279,13 +283,14 @@ class SACPolicy(BasePolicy):
|
|||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
||||
if features_extractor_kwargs is None:
|
||||
features_extractor_kwargs = {}
|
||||
|
||||
self.optimizer_class = optimizer
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
|
||||
self.features_extractor_class = features_extractor_class
|
||||
self.features_extractor_kwargs = features_extractor_kwargs
|
||||
if features_extractor_kwargs is None:
|
||||
features_extractor_kwargs = {}
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
|
|
@ -325,8 +330,7 @@ class SACPolicy(BasePolicy):
|
|||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
# Do not optimize the shared feature extractor with the critic loss
|
||||
# otherwise, there are gradient computation issues
|
||||
# another solution: having duplicated features extractor but requires more memory and computation
|
||||
# Note: check the gradients; they may be computed but not zeroed by the critic
|
||||
# Another solution: having duplicated features extractor but requires more memory and computation
|
||||
critic_parameters = [param for name, param in self.critic.named_parameters() if 'features_extractor' not in name]
|
||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1),
|
||||
**self.optimizer_kwargs)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ import gym
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
|
||||
from torchy_baselines.common.preprocessing import get_action_dim
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
create_sde_features_extractor)
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from torchy_baselines.common.distributions import StateDependentNoiseDistribution, Distribution
|
||||
|
||||
|
||||
|
|
@ -221,12 +222,15 @@ class Critic(BasePolicy):
|
|||
self.q2_net = nn.Sequential(*q2_net)
|
||||
|
||||
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
features = self.extract_features(obs)
|
||||
# Learn the features extractor using the policy loss only
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, actions], dim=1)
|
||||
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
|
||||
|
||||
def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
|
||||
features = self.extract_features(obs)
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
return self.q1_net(th.cat([features, actions], dim=1))
|
||||
|
||||
|
||||
|
|
@ -262,7 +266,8 @@ class ValueFunction(BasePolicy):
|
|||
self.vf_net = nn.Sequential(*vf_net)
|
||||
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
features = self.extract_features(obs)
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
return self.vf_net(features)
|
||||
|
||||
|
||||
|
|
@ -284,6 +289,9 @@ class TD3Policy(BasePolicy):
|
|||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
|
||||
a positive standard deviation (cf. paper). It makes it possible to keep the variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
|
||||
|
|
@ -303,6 +311,8 @@ class TD3Policy(BasePolicy):
|
|||
lr_sde: float = 3e-4,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
|
|
@ -310,17 +320,24 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
# Default network architecture, from the original paper
|
||||
if net_arch is None:
|
||||
net_arch = [400, 300]
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [400, 300]
|
||||
else:
|
||||
net_arch = []
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
||||
if features_extractor_kwargs is None:
|
||||
features_extractor_kwargs = {}
|
||||
|
||||
self.optimizer_class = optimizer
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
|
||||
# In the future, features_extractor will be replaced with a CNN
|
||||
self.features_extractor = nn.Flatten()
|
||||
self.features_dim = get_obs_dim(self.observation_space)
|
||||
self.features_extractor_class = features_extractor_class
|
||||
self.features_extractor_kwargs = features_extractor_kwargs
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
|
|
@ -384,7 +401,9 @@ class TD3Policy(BasePolicy):
|
|||
use_expln=self.actor_kwargs['use_expln'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
))
|
||||
return data
|
||||
|
||||
|
|
@ -406,4 +425,69 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
MlpPolicy = TD3Policy
|
||||
|
||||
|
||||
class CnnPolicy(TD3Policy):
    """
    TD3 policy (actor and critic) that uses ``NatureCNN`` as its default features extractor.

    The constructor contains no logic of its own: it forwards every argument,
    unchanged, to ``TD3Policy.__init__``.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (Callable) Learning rate schedule (could be constant)
    :param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
    :param device: (Union[th.device, str]) Device on which the code should run.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param clip_noise: (Optional[float]) Passed through to ``TD3Policy``.
        NOTE(review): its exact meaning is defined by the parent class -- confirm there.
    :param lr_sde: (float) Passed through to ``TD3Policy``.
        NOTE(review): presumably a learning rate related to SDE -- confirm in the parent class.
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using SDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
        a positive standard deviation (cf. paper). It makes it possible to keep the variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use
        (``NatureCNN`` by default).
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 lr_schedule: Callable,
                 net_arch: Optional[List[int]] = None,
                 device: Union[th.device, str] = 'auto',
                 activation_fn: Type[nn.Module] = nn.ReLU,
                 use_sde: bool = False,
                 log_std_init: float = -3,
                 clip_noise: Optional[float] = None,
                 lr_sde: float = 3e-4,
                 sde_net_arch: Optional[List[int]] = None,
                 use_expln: bool = False,
                 features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                 normalize_images: bool = True,
                 optimizer: Type[th.optim.Optimizer] = th.optim.Adam,
                 optimizer_kwargs: Optional[Dict[str, Any]] = None):
        # The leading arguments stay positional, in the parent's order; the
        # trailing ones are forwarded by name to keep the mapping explicit.
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            device,
            activation_fn,
            use_sde,
            log_std_init,
            clip_noise,
            lr_sde=lr_sde,
            sde_net_arch=sde_net_arch,
            use_expln=use_expln,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
        )
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
|
|
|
|||
Loading…
Reference in a new issue