mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
Merge branch 'base-class-review' of github.com:DLR-RM/stable-baselines3 into base-class-review
This commit is contained in:
commit
bf73f01ee9
8 changed files with 114 additions and 118 deletions
|
|
@ -17,6 +17,7 @@ notebooks:
|
|||
- `Monitor Training and Plotting`_
|
||||
- `Atari Games`_
|
||||
- `RL Baselines zoo`_
|
||||
- `PyBullet`_
|
||||
|
||||
.. - `Hindsight Experience Replay`_
|
||||
|
||||
|
|
@ -27,6 +28,7 @@ notebooks:
|
|||
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
|
||||
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
|
||||
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
|
||||
.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
|
||||
|
||||
.. |colab| image:: ../_static/img/colab.svg
|
||||
|
||||
|
|
@ -291,7 +293,7 @@ PyBullet: Normalizing input features
|
|||
|
||||
Normalizing input features may be essential to successful training of an RL agent
|
||||
(by default, images are scaled but not other types of input),
|
||||
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`_ environments. For that, a wrapper exists and
|
||||
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
|
||||
will compute a running average and standard deviation of input features (it can do the same for rewards).
|
||||
|
||||
|
||||
|
|
@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can
|
|||
you need to install pybullet with ``pip install pybullet``
|
||||
|
||||
|
||||
.. image:: ../_static/img/colab-badge.svg
|
||||
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.8.0a2 (WIP)
|
||||
Pre-Release 0.8.0a3 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones
|
||||
- ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi)
|
||||
- Refactored ``Critic`` class for ``TD3`` and ``SAC``; it is now called ``ContinuousCritic``
|
||||
and has an additional parameter ``n_critics``
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -40,6 +42,7 @@ Documentation:
|
|||
- Updated notebook links
|
||||
- Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake)
|
||||
- Added Unity reacher to the projects page (@koulakis)
|
||||
- Added PyBullet colab notebook
|
||||
|
||||
|
||||
|
||||
|
|
@ -342,4 +345,4 @@ And all the contributors:
|
|||
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
|
||||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
@tirafesi @blurLake @koulakis
|
||||
@tirafesi @blurLake @koulakis
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space
|
||||
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim
|
||||
from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
|
||||
NatureCNN, MlpExtractor)
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation
|
||||
|
|
@ -644,6 +644,74 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
optimizer_kwargs)
|
||||
|
||||
|
||||
class ContinuousCritic(BaseModel):
|
||||
"""
|
||||
Critic network(s) for DDPG/SAC/TD3.
|
||||
It represents the action-state value function (Q-value function).
|
||||
Compared to A2C/PPO critics, this one represents the Q-value
|
||||
and takes the continuous action as input. It is concatenated with the state
|
||||
and then fed to the network which outputs a single value: Q(s, a).
|
||||
For more recent algorithms like SAC/TD3, multiple networks
|
||||
are created to give different estimates.
|
||||
|
||||
By default, it creates two critic networks used to reduce overestimation
|
||||
thanks to clipped Q-learning (cf TD3 paper).
|
||||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param features_extractor: (nn.Module) Network to extract features
|
||||
(a CNN when using images, a nn.Flatten() layer otherwise)
|
||||
:param features_dim: (int) Number of features
|
||||
:param activation_fn: (Type[nn.Module]) Activation function
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
:param n_critics: (int) Number of critic networks to create.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
n_critics: int = 2):
|
||||
super().__init__(observation_space, action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
device=device)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
self.n_critics = n_critics
|
||||
self.q_networks = []
|
||||
for idx in range(n_critics):
|
||||
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
q_net = nn.Sequential(*q_net)
|
||||
self.add_module(f'qf{idx}', q_net)
|
||||
self.q_networks.append(q_net)
|
||||
|
||||
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, actions], dim=1)
|
||||
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
|
||||
|
||||
def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Only predict the Q-value using the first network.
|
||||
This reduces computation when not all the estimates are needed
|
||||
(e.g. when updating the policy in TD3).
|
||||
"""
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
return self.q_networks[0](th.cat([features, actions], dim=1))
|
||||
|
||||
|
||||
def create_sde_features_extractor(features_dim: int,
|
||||
sde_net_arch: List[int],
|
||||
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy, create_sde_features_extractor
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic
|
||||
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
|
|
@ -179,54 +179,6 @@ class Actor(BasePolicy):
|
|||
return self.forward(observation, deterministic)
|
||||
|
||||
|
||||
class Critic(BaseModel):
|
||||
"""
|
||||
Critic network (q-value function) for SAC.
|
||||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param features_extractor: (nn.Module) Network to extract features
|
||||
(a CNN when using images, a nn.Flatten() layer otherwise)
|
||||
:param features_dim: (int) Number of features
|
||||
:param activation_fn: (Type[nn.Module]) Activation function
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
device: Union[th.device, str] = 'auto'):
|
||||
super(Critic, self).__init__(observation_space, action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
device=device)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q1_net = nn.Sequential(*q1_net)
|
||||
|
||||
q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q2_net = nn.Sequential(*q2_net)
|
||||
|
||||
self.q_networks = [self.q1_net, self.q2_net]
|
||||
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
# this is much faster
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, action], dim=1)
|
||||
return [q_net(qvalue_input) for q_net in self.q_networks]
|
||||
|
||||
|
||||
class SACPolicy(BasePolicy):
|
||||
"""
|
||||
Policy class (with both actor and critic) for SAC.
|
||||
|
|
@ -255,6 +207,7 @@ class SACPolicy(BasePolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: (int) Number of critic networks to create.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
|
|
@ -272,7 +225,8 @@ class SACPolicy(BasePolicy):
|
|||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2):
|
||||
super(SACPolicy, self).__init__(observation_space, action_space,
|
||||
device,
|
||||
features_extractor_class,
|
||||
|
|
@ -313,6 +267,9 @@ class SACPolicy(BasePolicy):
|
|||
'clip_mean': clip_mean
|
||||
}
|
||||
self.actor_kwargs.update(sde_kwargs)
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs.update({'n_critics': n_critics})
|
||||
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
|
||||
|
|
@ -345,6 +302,7 @@ class SACPolicy(BasePolicy):
|
|||
sde_net_arch=self.actor_kwargs['sde_net_arch'],
|
||||
use_expln=self.actor_kwargs['use_expln'],
|
||||
clip_mean=self.actor_kwargs['clip_mean'],
|
||||
n_critics=self.critic_kwargs['n_critics'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
|
|
@ -364,8 +322,8 @@ class SACPolicy(BasePolicy):
|
|||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self) -> Critic:
|
||||
return Critic(**self.net_args).to(self.device)
|
||||
def make_critic(self) -> ContinuousCritic:
|
||||
return ContinuousCritic(**self.critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self._predict(obs, deterministic=deterministic)
|
||||
|
|
@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: (int) Number of critic networks to create.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
|
|
@ -420,7 +379,8 @@ class CnnPolicy(SACPolicy):
|
|||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2):
|
||||
super(CnnPolicy, self).__init__(observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -436,7 +396,8 @@ class CnnPolicy(SACPolicy):
|
|||
features_extractor_kwargs,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs)
|
||||
optimizer_kwargs,
|
||||
n_critics)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class SAC(OffPolicyAlgorithm):
|
|||
def _setup_model(self) -> None:
|
||||
super(SAC, self)._setup_model()
|
||||
self._create_aliases()
|
||||
|
||||
assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now"
|
||||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == 'auto':
|
||||
# automatically set target entropy if needed
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict
|
||||
from typing import Optional, List, Callable, Union, Type, Any, Dict
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic
|
||||
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
|
||||
|
||||
|
||||
|
|
@ -71,57 +71,6 @@ class Actor(BasePolicy):
|
|||
return self.forward(observation, deterministic=deterministic)
|
||||
|
||||
|
||||
class Critic(BaseModel):
|
||||
"""
|
||||
Critic network for TD3,
|
||||
in fact it represents the action-state value function (Q-value function)
|
||||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param features_extractor: (nn.Module) Network to extract features
|
||||
(a CNN when using images, a nn.Flatten() layer otherwise)
|
||||
:param features_dim: (int) Number of features
|
||||
:param activation_fn: (Type[nn.Module]) Activation function
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
net_arch: List[int],
|
||||
features_extractor: nn.Module,
|
||||
features_dim: int,
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
normalize_images: bool = True,
|
||||
device: Union[th.device, str] = 'auto'):
|
||||
super(Critic, self).__init__(observation_space, action_space,
|
||||
features_extractor=features_extractor,
|
||||
normalize_images=normalize_images,
|
||||
device=device)
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q1_net = nn.Sequential(*q1_net)
|
||||
|
||||
q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q2_net = nn.Sequential(*q2_net)
|
||||
|
||||
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
# Learn the features extractor using the policy loss only
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
qvalue_input = th.cat([features, actions], dim=1)
|
||||
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
|
||||
|
||||
def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
|
||||
with th.no_grad():
|
||||
features = self.extract_features(obs)
|
||||
return self.q1_net(th.cat([features, actions], dim=1))
|
||||
|
||||
|
||||
class TD3Policy(BasePolicy):
|
||||
"""
|
||||
Policy class (with both actor and critic) for TD3.
|
||||
|
|
@ -141,6 +90,7 @@ class TD3Policy(BasePolicy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: (int) Number of critic networks to create.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
|
|
@ -153,7 +103,8 @@ class TD3Policy(BasePolicy):
|
|||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space,
|
||||
device,
|
||||
features_extractor_class,
|
||||
|
|
@ -185,6 +136,8 @@ class TD3Policy(BasePolicy):
|
|||
'normalize_images': normalize_images,
|
||||
'device': device
|
||||
}
|
||||
self.critic_kwargs = self.net_args.copy()
|
||||
self.critic_kwargs.update({'n_critics': n_critics})
|
||||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
|
||||
|
|
@ -208,6 +161,7 @@ class TD3Policy(BasePolicy):
|
|||
data.update(dict(
|
||||
net_arch=self.net_args['net_arch'],
|
||||
activation_fn=self.net_args['activation_fn'],
|
||||
n_critics=self.critic_kwargs['n_critics'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
|
|
@ -219,8 +173,8 @@ class TD3Policy(BasePolicy):
|
|||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.net_args).to(self.device)
|
||||
|
||||
def make_critic(self) -> Critic:
|
||||
return Critic(**self.net_args).to(self.device)
|
||||
def make_critic(self) -> ContinuousCritic:
|
||||
return ContinuousCritic(**self.critic_kwargs).to(self.device)
|
||||
|
||||
def forward(self, observation: th.Tensor, deterministic: bool = False):
|
||||
return self._predict(observation, deterministic=deterministic)
|
||||
|
|
@ -251,6 +205,7 @@ class CnnPolicy(TD3Policy):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param n_critics: (int) Number of critic networks to create.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
|
|
@ -263,7 +218,8 @@ class CnnPolicy(TD3Policy):
|
|||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
n_critics: int = 2):
|
||||
super(CnnPolicy, self).__init__(observation_space,
|
||||
action_space,
|
||||
lr_schedule,
|
||||
|
|
@ -274,7 +230,8 @@ class CnnPolicy(TD3Policy):
|
|||
features_extractor_kwargs,
|
||||
normalize_images,
|
||||
optimizer_class,
|
||||
optimizer_kwargs)
|
||||
optimizer_kwargs,
|
||||
n_critics)
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ class TD3(OffPolicyAlgorithm):
|
|||
def _setup_model(self) -> None:
|
||||
super(TD3, self)._setup_model()
|
||||
self._create_aliases()
|
||||
assert self.critic.n_critics == 2, "TD3 only supports `n_critics=2` for now"
|
||||
|
||||
def _create_aliases(self) -> None:
|
||||
self.actor = self.policy.actor
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.8.0a2
|
||||
0.8.0a3
|
||||
|
|
|
|||
Loading…
Reference in a new issue