mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-30 03:38:13 +00:00
Fix off policy features extractor (#198)
* Faster tests * Fix feature extractor bug + add check * Add missing check * Allow TD3 features extractor to be separate * Add share features extractor option for SAC * Bug fixes * Apply suggestions from code review Co-authored-by: Adam Gleave <adam@gleave.me> Co-authored-by: Adam Gleave <adam@gleave.me>
This commit is contained in:
parent
b252f4212c
commit
0fc0dd1b21
12 changed files with 286 additions and 59 deletions
|
|
@ -17,6 +17,7 @@ New Features:
|
|||
- Added Hindsight Experience Replay ``HER``. (@megan-klaiber)
|
||||
- ``VecNormalize`` now supports ``gym.spaces.Dict`` observation spaces
|
||||
- Support logging videos to Tensorboard (@SwamyDev)
|
||||
- Added ``share_features_extractor`` argument to ``SAC`` and ``TD3`` policies
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -27,6 +28,8 @@ Bug Fixes:
|
|||
- Fix model creation initializing CUDA even when `device="cpu"` is provided
|
||||
- Fix ``check_env`` not checking if the env has a Dict action space before calling ``_check_nan`` (@wmmc88)
|
||||
- Update the check for spaces unsupported by Stable Baselines 3 to include checks on the action space (@wmmc88)
|
||||
- Fixed feature extractor bug for target network where the same net was shared instead
|
||||
of being separate. This bug affects ``SAC``, ``DDPG`` and ``TD3`` when using ``CnnPolicy`` (or custom feature extractor)
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import warnings
|
||||
|
||||
from stable_baselines3.common.env_util import * # noqa: F403
|
||||
from stable_baselines3.common.env_util import * # noqa: F403,F401
|
||||
|
||||
warnings.warn(
|
||||
"Module ``common.cmd_util`` has been renamed to ``common.env_util`` and will be removed in the future.", FutureWarning
|
||||
|
|
|
|||
|
|
@ -505,7 +505,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
|
||||
:param latent_dim: Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: Initial value for the log standard deviation
|
||||
:param latent_sde_dim: Dimension of the last layer of the feature extractor
|
||||
:param latent_sde_dim: Dimension of the last layer of the features extractor
|
||||
for gSDE. By default, it is shared with the policy network.
|
||||
:return:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class BaseModel(nn.Module, ABC):
|
|||
:param action_space: The action space of the environment
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param features_extractor: Network to extract features
|
||||
(a CNN when using images, a nn.Flatten() layer otherwise)
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
|
|
@ -83,6 +83,30 @@ class BaseModel(nn.Module, ABC):
|
|||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
def _update_features_extractor(
|
||||
self, net_kwargs: Dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update the network keyword arguments and create a new features extractor object if needed.
|
||||
If a ``features_extractor`` object is passed, then it will be shared.
|
||||
|
||||
:param net_kwargs: the base network keyword arguments, without the ones
|
||||
related to features extractor
|
||||
:param features_extractor: a features extractor object.
|
||||
If None, a new object will be created.
|
||||
:return: The updated keyword arguments
|
||||
"""
|
||||
net_kwargs = net_kwargs.copy()
|
||||
if features_extractor is None:
|
||||
# The features extractor is not shared, create a new one
|
||||
features_extractor = self.make_features_extractor()
|
||||
net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim))
|
||||
return net_kwargs
|
||||
|
||||
def make_features_extractor(self) -> BaseFeaturesExtractor:
|
||||
""" Helper method to create a features extractor."""
|
||||
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
|
@ -90,7 +114,7 @@ class BaseModel(nn.Module, ABC):
|
|||
:param obs:
|
||||
:return:
|
||||
"""
|
||||
assert self.features_extractor is not None, "No feature extractor was set"
|
||||
assert self.features_extractor is not None, "No features extractor was set"
|
||||
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
|
||||
return self.features_extractor(preprocessed_obs)
|
||||
|
||||
|
|
@ -327,7 +351,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -466,7 +490,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
|
||||
latent_dim_pi = self.mlp_extractor.latent_dim_pi
|
||||
|
||||
# Separate feature extractor for gSDE
|
||||
# Separate features extractor for gSDE
|
||||
if self.sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
||||
self.features_dim, self.sde_net_arch, self.activation_fn
|
||||
|
|
@ -496,7 +520,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
if self.ortho_init:
|
||||
# TODO: check for features_extractor
|
||||
# Values from stable-baselines.
|
||||
# feature_extractor/mlp values are
|
||||
# features_extractor/mlp values are
|
||||
# originally from openai/baselines (default gains/init_scales).
|
||||
module_gains = {
|
||||
self.features_extractor: np.sqrt(2),
|
||||
|
|
@ -625,7 +649,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -698,6 +722,8 @@ class ContinuousCritic(BaseModel):
|
|||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether the features extractor is shared or not
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -710,6 +736,7 @@ class ContinuousCritic(BaseModel):
|
|||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
observation_space,
|
||||
|
|
@ -720,6 +747,7 @@ class ContinuousCritic(BaseModel):
|
|||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
self.share_features_extractor = share_features_extractor
|
||||
self.n_critics = n_critics
|
||||
self.q_networks = []
|
||||
for idx in range(n_critics):
|
||||
|
|
@ -730,7 +758,8 @@ class ContinuousCritic(BaseModel):
|
|||
|
||||
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
with th.no_grad():
|
||||
# when the features_extractor is shared with the actor
|
||||
with th.set_grad_enabled(not self.share_features_extractor):
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, actions], dim=1)
|
||||
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> T
|
|||
then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.
|
||||
|
||||
.. note::
|
||||
Compared to their on-policy counterparts, no shared layers (other than the feature extractor)
|
||||
Compared to their on-policy counterparts, no shared layers (other than the features extractor)
|
||||
between the actor and the critic are allowed (to prevent issues with target networks).
|
||||
|
||||
:param net_arch: The specification of the actor and critic networks.
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class DQNPolicy(BasePolicy):
|
|||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -158,10 +158,9 @@ class DQNPolicy(BasePolicy):
|
|||
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def make_q_net(self) -> QNetwork:
|
||||
# Make sure we always have separate networks for feature extractors etc
|
||||
features_extractor = self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
features_dim = features_extractor.features_dim
|
||||
return QNetwork(features_extractor=features_extractor, features_dim=features_dim, **self.net_args).to(self.device)
|
||||
# Make sure we always have separate networks for features extractors etc
|
||||
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
|
||||
return QNetwork(**net_args).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class Actor(BasePolicy):
|
|||
|
||||
if self.use_sde:
|
||||
latent_sde_dim = last_layer_dim
|
||||
# Separate feature extractor for gSDE
|
||||
# Separate features extractor for gSDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
|
||||
features_dim, sde_net_arch, activation_fn
|
||||
|
|
@ -211,7 +211,7 @@ class SACPolicy(BasePolicy):
|
|||
:param clip_mean: Clip the mean output when using gSDE to avoid numerical instability.
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -219,6 +219,8 @@ class SACPolicy(BasePolicy):
|
|||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -239,6 +241,7 @@ class SACPolicy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(SACPolicy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -258,17 +261,11 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
||||
# Create shared features extractor
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
"action_space": self.action_space,
|
||||
"features_extractor": self.features_extractor,
|
||||
"features_dim": self.features_dim,
|
||||
"net_arch": actor_arch,
|
||||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
|
|
@ -283,10 +280,17 @@ class SACPolicy(BasePolicy):
|
|||
}
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})
|
||||
self.critic_kwargs.update(
|
||||
{
|
||||
"n_critics": n_critics,
|
||||
"net_arch": critic_arch,
|
||||
"share_features_extractor": share_features_extractor,
|
||||
}
|
||||
)
|
||||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
self.share_features_extractor = share_features_extractor
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
|
|
@ -294,13 +298,21 @@ class SACPolicy(BasePolicy):
|
|||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
if self.share_features_extractor:
|
||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||
# Do not optimize the shared features extractor with the critic loss
|
||||
# otherwise, there are gradient computation issues
|
||||
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
||||
else:
|
||||
# Create a separate features extractor for the critic
|
||||
# this requires more memory and computation
|
||||
self.critic = self.make_critic(features_extractor=None)
|
||||
critic_parameters = self.critic.parameters()
|
||||
|
||||
# Critic target should not share the features extractor with critic
|
||||
self.critic_target = self.make_critic(features_extractor=None)
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
# Do not optimize the shared feature extractor with the critic loss
|
||||
# otherwise, there are gradient computation issues
|
||||
# Another solution: having duplicated features extractor but requires more memory and computation
|
||||
critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
|
||||
|
||||
self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
|
|
@ -333,11 +345,13 @@ class SACPolicy(BasePolicy):
|
|||
"""
|
||||
self.actor.reset_noise(batch_size=batch_size)
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||
return Actor(**actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self) -> ContinuousCritic:
|
||||
return ContinuousCritic(**self.critic_kwargs).to(self.device)
|
||||
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
|
||||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return ContinuousCritic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
|
@ -375,6 +389,8 @@ class CnnPolicy(SACPolicy):
|
|||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -395,6 +411,7 @@ class CnnPolicy(SACPolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -413,6 +430,7 @@ class CnnPolicy(SACPolicy):
|
|||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
n_critics,
|
||||
share_features_extractor,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class TD3Policy(BasePolicy):
|
|||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -100,6 +100,8 @@ class TD3Policy(BasePolicy):
|
|||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -115,6 +117,7 @@ class TD3Policy(BasePolicy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(TD3Policy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -135,35 +138,54 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
actor_arch, critic_arch = get_actor_critic_arch(net_arch)
|
||||
|
||||
self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.net_args = {
|
||||
"observation_space": self.observation_space,
|
||||
"action_space": self.action_space,
|
||||
"features_extractor": self.features_extractor,
|
||||
"features_dim": self.features_dim,
|
||||
"net_arch": actor_arch,
|
||||
"activation_fn": self.activation_fn,
|
||||
"normalize_images": normalize_images,
|
||||
}
|
||||
self.actor_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs.update({"n_critics": n_critics, "net_arch": critic_arch})
|
||||
self.critic_kwargs.update(
|
||||
{
|
||||
"n_critics": n_critics,
|
||||
"net_arch": critic_arch,
|
||||
"share_features_extractor": share_features_extractor,
|
||||
}
|
||||
)
|
||||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
self.share_features_extractor = share_features_extractor
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor_target = self.make_actor()
|
||||
# Create actor and target
|
||||
# the features extractor should not be shared
|
||||
self.actor = self.make_actor(features_extractor=None)
|
||||
self.actor_target = self.make_actor(features_extractor=None)
|
||||
# Initialize the target to have the same weights as the actor
|
||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
||||
|
||||
self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
|
||||
if self.share_features_extractor:
|
||||
self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
|
||||
# Critic target should not share the features extractor with critic
|
||||
# but it can share it with the actor target as actor and critic are sharing
|
||||
# the same features_extractor too
|
||||
# NOTE: as a result the effective polyak (soft-copy) coefficient for the features extractor
|
||||
# will be 2 * tau instead of tau (updated one time with the actor, a second time with the critic)
|
||||
self.critic_target = self.make_critic(features_extractor=self.actor_target.features_extractor)
|
||||
else:
|
||||
# Create new features extractor for each network
|
||||
self.critic = self.make_critic(features_extractor=None)
|
||||
self.critic_target = self.make_critic(features_extractor=None)
|
||||
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
|
||||
|
||||
|
|
@ -180,15 +202,18 @@ class TD3Policy(BasePolicy):
|
|||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs,
|
||||
share_features_extractor=self.share_features_extractor,
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
|
||||
actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
|
||||
return Actor(**actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self) -> ContinuousCritic:
|
||||
return ContinuousCritic(**self.critic_kwargs).to(self.device)
|
||||
def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
|
||||
critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
|
||||
return ContinuousCritic(**critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(observation, deterministic=deterministic)
|
||||
|
|
@ -211,7 +236,7 @@ class CnnPolicy(TD3Policy):
|
|||
:param activation_fn: Activation function
|
||||
:param features_extractor_class: Features extractor to use.
|
||||
:param features_extractor_kwargs: Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
to pass to the features extractor.
|
||||
:param normalize_images: Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: The optimizer to use,
|
||||
|
|
@ -219,6 +244,8 @@ class CnnPolicy(TD3Policy):
|
|||
:param optimizer_kwargs: Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: Number of critic networks to create.
|
||||
:param share_features_extractor: Whether to share or not the features extractor
|
||||
between the actor and the critic (this saves computation time)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -234,6 +261,7 @@ class CnnPolicy(TD3Policy):
|
|||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2,
|
||||
share_features_extractor: bool = True,
|
||||
):
|
||||
super(CnnPolicy, self).__init__(
|
||||
observation_space,
|
||||
|
|
@ -247,6 +275,7 @@ class CnnPolicy(TD3Policy):
|
|||
optimizer_class,
|
||||
optimizer_kwargs,
|
||||
n_critics,
|
||||
share_features_extractor,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,13 @@
|
|||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch as th
|
||||
|
||||
from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
|
||||
from stable_baselines3.common.identity_env import FakeImageEnv
|
||||
from stable_baselines3.common.utils import zip_strict
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
|
||||
|
|
@ -35,3 +38,139 @@ def test_cnn(tmp_path, model_class):
|
|||
assert np.allclose(action, model.predict(obs, deterministic=True)[0])
|
||||
|
||||
os.remove(str(tmp_path / SAVE_NAME))
|
||||
|
||||
|
||||
def patch_dqn_names_(model):
|
||||
# Small hack to make the test work with DQN
|
||||
if isinstance(model, DQN):
|
||||
model.critic = model.q_net
|
||||
model.critic_target = model.q_net_target
|
||||
|
||||
|
||||
def params_should_match(params, other_params):
|
||||
for param, other_param in zip_strict(params, other_params):
|
||||
assert th.allclose(param, other_param)
|
||||
|
||||
|
||||
def params_should_differ(params, other_params):
|
||||
for param, other_param in zip_strict(params, other_params):
|
||||
assert not th.allclose(param, other_param)
|
||||
|
||||
|
||||
def check_td3_feature_extractor_match(model):
|
||||
for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()):
|
||||
if "features_extractor" in key:
|
||||
assert th.allclose(actor_param, critic_param), key
|
||||
|
||||
|
||||
def check_td3_feature_extractor_differ(model):
|
||||
for (key, actor_param), critic_param in zip(model.actor_target.named_parameters(), model.critic_target.parameters()):
|
||||
if "features_extractor" in key:
|
||||
assert not th.allclose(actor_param, critic_param), key
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
|
||||
@pytest.mark.parametrize("share_features_extractor", [True, False])
|
||||
def test_features_extractor_target_net(model_class, share_features_extractor):
|
||||
if model_class == DQN and share_features_extractor:
|
||||
pytest.skip()
|
||||
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
|
||||
if model_class != DQN:
|
||||
kwargs["policy_kwargs"]["share_features_extractor"] = share_features_extractor
|
||||
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs)
|
||||
|
||||
patch_dqn_names_(model)
|
||||
|
||||
if share_features_extractor:
|
||||
# Check that the objects are the same and not just copied
|
||||
assert id(model.policy.actor.features_extractor) == id(model.policy.critic.features_extractor)
|
||||
if model_class == TD3:
|
||||
assert id(model.policy.actor_target.features_extractor) == id(model.policy.critic_target.features_extractor)
|
||||
# Actor and critic feature extractor should be the same
|
||||
td3_features_extractor_check = check_td3_feature_extractor_match
|
||||
else:
|
||||
# Actor and critic features extractor should differ
|
||||
td3_features_extractor_check = check_td3_feature_extractor_differ
|
||||
# Check that the objects differ
|
||||
if model_class != DQN:
|
||||
assert id(model.policy.actor.features_extractor) != id(model.policy.critic.features_extractor)
|
||||
|
||||
if model_class == TD3:
|
||||
assert id(model.policy.actor_target.features_extractor) != id(model.policy.critic_target.features_extractor)
|
||||
|
||||
# Critic and target should be equal at the beginning of training
|
||||
params_should_match(model.critic.parameters(), model.critic_target.parameters())
|
||||
|
||||
# TD3 also has a target actor net
|
||||
if model_class == TD3:
|
||||
params_should_match(model.actor.parameters(), model.actor_target.parameters())
|
||||
|
||||
model.learn(200)
|
||||
|
||||
# Critic and target should differ
|
||||
params_should_differ(model.critic.parameters(), model.critic_target.parameters())
|
||||
|
||||
if model_class == TD3:
|
||||
params_should_differ(model.actor.parameters(), model.actor_target.parameters())
|
||||
td3_features_extractor_check(model)
|
||||
|
||||
# Re-initialize and collect some random data (without doing gradient steps,
|
||||
# since 10 < learning_starts = 100)
|
||||
model = model_class("CnnPolicy", env, seed=0, **kwargs).learn(10)
|
||||
|
||||
patch_dqn_names_(model)
|
||||
|
||||
original_param = deepcopy(list(model.critic.parameters()))
|
||||
original_target_param = deepcopy(list(model.critic_target.parameters()))
|
||||
if model_class == TD3:
|
||||
original_actor_target_param = deepcopy(list(model.actor_target.parameters()))
|
||||
|
||||
# Deactivate copy to target
|
||||
model.tau = 0.0
|
||||
model.train(gradient_steps=1)
|
||||
|
||||
# Target should be the same
|
||||
params_should_match(original_target_param, model.critic_target.parameters())
|
||||
|
||||
if model_class == TD3:
|
||||
params_should_match(original_actor_target_param, model.actor_target.parameters())
|
||||
td3_features_extractor_check(model)
|
||||
|
||||
# not the same for critic net (updated by gradient descent)
|
||||
params_should_differ(original_param, model.critic.parameters())
|
||||
|
||||
# Update the reference as it should not change in the next step
|
||||
original_param = deepcopy(list(model.critic.parameters()))
|
||||
|
||||
if model_class == TD3:
|
||||
original_actor_param = deepcopy(list(model.actor.parameters()))
|
||||
|
||||
# Deactivate learning rate
|
||||
model.lr_schedule = lambda _: 0.0
|
||||
# Re-activate polyak update
|
||||
model.tau = 0.01
|
||||
# Special case for DQN: target net is updated in the `collect_rollout()`
|
||||
# not the `train()` method
|
||||
if model_class == DQN:
|
||||
model.target_update_interval = 1
|
||||
model._on_step()
|
||||
|
||||
model.train(gradient_steps=1)
|
||||
|
||||
# Target should have changed now (due to polyak update)
|
||||
params_should_differ(original_target_param, model.critic_target.parameters())
|
||||
|
||||
# Critic should be the same
|
||||
params_should_match(original_param, model.critic.parameters())
|
||||
|
||||
if model_class == TD3:
|
||||
params_should_differ(original_actor_target_param, model.actor_target.parameters())
|
||||
|
||||
params_should_match(original_actor_param, model.actor.parameters())
|
||||
|
||||
td3_features_extractor_check(model)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@ def test_deterministic_pg(model_class, action_noise):
|
|||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
action_noise=action_noise,
|
||||
)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
|
||||
|
|
@ -50,6 +51,7 @@ def test_ppo(env_id, clip_range_vf):
|
|||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env_id,
|
||||
n_steps=512,
|
||||
seed=0,
|
||||
policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1,
|
||||
|
|
@ -68,19 +70,25 @@ def test_sac(ent_coef):
|
|||
learning_starts=100,
|
||||
verbose=1,
|
||||
create_eval_env=True,
|
||||
buffer_size=250,
|
||||
ent_coef=ent_coef,
|
||||
action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
|
||||
)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
model.learn(total_timesteps=300, eval_freq=250)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_critics", [1, 3])
|
||||
def test_n_critics(n_critics):
|
||||
# Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG
|
||||
model = SAC(
|
||||
"MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
|
||||
"MlpPolicy",
|
||||
"Pendulum-v0",
|
||||
policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics),
|
||||
learning_starts=100,
|
||||
buffer_size=10000,
|
||||
verbose=1,
|
||||
)
|
||||
model.learn(total_timesteps=500)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
|
||||
def test_dqn():
|
||||
|
|
|
|||
|
|
@ -303,21 +303,23 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
|
|||
:param model_class: (BaseAlgorithm) A RL model
|
||||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = {}
|
||||
kwargs = dict(policy_kwargs=dict(net_arch=[16]))
|
||||
if policy_str == "MlpPolicy":
|
||||
env = select_env(model_class)
|
||||
else:
|
||||
if model_class in [SAC, TD3, DQN, DDPG]:
|
||||
# Avoid memory error when using replay buffer
|
||||
# Reduce the size of the features
|
||||
kwargs = dict(buffer_size=250)
|
||||
kwargs = dict(
|
||||
buffer_size=250, learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))
|
||||
)
|
||||
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
|
||||
# create model
|
||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=500)
|
||||
model = model_class(policy_str, env, verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=300)
|
||||
|
||||
env.reset()
|
||||
observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
|
||||
|
|
|
|||
|
|
@ -195,4 +195,4 @@ def test_zip_strict():
|
|||
def test_cmd_util_rename():
|
||||
"""Test that importing cmd_util still works but raises warning"""
|
||||
with pytest.warns(FutureWarning):
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.cmd_util import make_vec_env # noqa: F401
|
||||
|
|
|
|||
Loading…
Reference in a new issue