From b702884c23b6aeaa5d2a830b37d6b15fb1bdf983 Mon Sep 17 00:00:00 2001 From: Alex Pasquali Date: Mon, 23 Jan 2023 14:55:19 +0100 Subject: [PATCH] Removed shared layers in mlp_extractor (#1292) * Modified actor-critic policies & MlpExtractor class ActorCriticPolicy: - changed type hint of net_arch param: now it's a dict - removed check that if features extractor is not shared: no shared layers are allowed in the mlp_extractor regardless of the features extractor ActorCriticCnnPolicy: - changed type hint of net_arch param: now it's a dict MultiInputActorcriticPolicy: - changed type hint of net_arch param: now it's a dict MlpExtractor: - changed type hint of net_arch param: now it's a dict - adapted networks creation - adapted methods: forward, forward_actor & forward_critic * Removed shared layers in mlp_extractor * Updated docs and changelog + reformat * Updated custom policy tests * Removed test on deprecation warning for share layers in mlp_extractor Now shared layers are removed * Update version * Update RL Zoo doc * Fix linter warnings * Add ruff to Makefile (experimental) * Add backward compat code and minor updates * Update tests * Add backward compatibility * Fix test * Improve compat code Co-authored-by: Antonin RAFFIN --- Makefile | 7 ++ docs/guide/custom_policy.rst | 45 ++++------ docs/guide/rl_zoo.rst | 19 ++-- docs/misc/changelog.rst | 3 +- stable_baselines3/common/base_class.py | 6 +- stable_baselines3/common/buffers.py | 29 ++++-- stable_baselines3/common/policies.py | 37 +++----- stable_baselines3/common/torch_layers.py | 109 +++++++---------------- stable_baselines3/version.txt | 2 +- tests/test_custom_policy.py | 25 ++---- tests/test_identity.py | 2 + 11 files changed, 122 insertions(+), 162 deletions(-) diff --git a/Makefile b/Makefile index c806507..6351162 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,13 @@ lint: # exit-zero treats all errors as warnings. flake8 ${LINT_PATHS} --count --exit-zero --statistics +ruff: + # stop the build if there are Python syntax errors or undefined names + # see https://lintlyci.github.io/Flake8Rules/ + ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source + # exit-zero treats all errors as warnings. + ruff ${LINT_PATHS} --exit-zero --line-length 127 + format: # Sort imports isort ${LINT_PATHS} diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index c9e598e..dae6048 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -117,11 +117,6 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t ``policy_kwargs`` (both for on-policy and off-policy algorithms). -.. warning:: - If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``. - Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared. - - .. code-block:: python import torch as th @@ -242,41 +237,31 @@ On-Policy Algorithms Custom Networks --------------- -.. warning:: - Shared layers in the the ``mlp_extractor`` are **deprecated**. - In a future release all layers will have to be non-shared. - If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_). - -.. warning:: - In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change - to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently) - for the actor and the critic, with the same architecture. - - If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``, you can pass a dictionary of the following structure: ``dict(pi=[], vf=[])``. For example, if you want a different architecture for the actor (aka ``pi``) and the critic ( value-function aka ``vf``) networks, then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``. -.. Otherwise, to have actor and critic that share the same network architecture, -.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each). +Otherwise, to have actor and critic that share the same network architecture, +you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each, this is equivalent to ``net_arch=dict(pi=[128, 128], vf=[128, 128])``). + +If shared layers are needed, you need to implement a custom policy network (see `advanced example below <#advanced-example>`_). Examples ~~~~~~~~ -.. TODO(antonin): uncomment when shared network is removed -.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]`` -.. -.. .. code-block:: none -.. -.. obs -.. / \ -.. <128> <128> -.. | | -.. <128> <128> -.. | | -.. action value +Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]`` + +.. code-block:: none + + obs + / \ + <128> <128> + | | + <128> <128> + | | + action value Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])`` diff --git a/docs/guide/rl_zoo.rst b/docs/guide/rl_zoo.rst index ea15832..8a611d8 100644 --- a/docs/guide/rl_zoo.rst +++ b/docs/guide/rl_zoo.rst @@ -20,6 +20,10 @@ Goals of this repository: Installation ------------ +Option 1: install the python package ``pip install rl_zoo3`` + +or: + 1. Clone the repository: :: @@ -42,7 +46,10 @@ Installation :: apt-get install swig cmake ffmpeg + # full dependencies pip install -r requirements.txt + # minimal dependencies + pip install -e . Train an Agent @@ -56,13 +63,13 @@ using: :: - python train.py --algo algo_name --env env_id + python -m rl_zoo3.train --algo algo_name --env env_id For example (with evaluation and checkpoints): :: - python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000 + python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000 Continue training (here, load pretrained agent for Breakout and continue @@ -70,7 +77,7 @@ training for 5000 steps): :: - python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000 + python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000 Enjoy a Trained Agent @@ -80,13 +87,13 @@ If the trained agent exists, then you can see it in action using: :: - python enjoy.py --algo algo_name --env env_id + python -m rl_zoo3.enjoy --algo algo_name --env env_id For example, enjoy A2C on Breakout during 5000 timesteps: :: - python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000 + python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000 Hyperparameter Optimization @@ -100,7 +107,7 @@ with a budget of 1000 trials and a maximum of 50000 steps: :: - python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \ + python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \ --sampler random --pruner median diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index db9d6c6..7cb344b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,12 +4,13 @@ Changelog ========== -Release 1.8.0a1 (WIP) +Release 1.8.0a2 (WIP) -------------------------- Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Removed shared layers in ``mlp_extractor`` (@AlexPasqua) New Features: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index a71043d..b6fba85 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -667,6 +667,11 @@ class BaseAlgorithm(ABC): if "policy_kwargs" in data: if "device" in data["policy_kwargs"]: del data["policy_kwargs"]["device"] + # backward compatibility, convert to new format + if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0: + saved_net_arch = data["policy_kwargs"]["net_arch"] + if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict): + data["policy_kwargs"]["net_arch"] = saved_net_arch[0] if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: raise ValueError( @@ -726,7 +731,6 @@ class BaseAlgorithm(ABC): ) else: raise e - # put other pytorch variables back in place if pytorch_variables is not None: for name in pytorch_variables: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index f9f0c72..f71dd29 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -474,7 +474,11 @@ class RolloutBuffer(BaseBuffer): yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME data = ( self.observations[batch_inds], self.actions[batch_inds], @@ -603,7 +607,11 @@ class DictReplayBuffer(ReplayBuffer): self.full = True self.pos = 0 - def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME: + def sample( + self, + batch_size: int, + env: Optional[VecNormalize] = None, + ) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME: """ Sample elements from the replay buffer. @@ -614,7 +622,11 @@ class DictReplayBuffer(ReplayBuffer): """ return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env) - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME: + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME: # Sample randomly the env idx env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),)) @@ -743,7 +755,10 @@ class DictRolloutBuffer(RolloutBuffer): if self.pos == self.buffer_size: self.full = True - def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME + def get( + self, + batch_size: Optional[int] = None, + ) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data @@ -767,7 +782,11 @@ class DictRolloutBuffer(RolloutBuffer): yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + def _get_samples( + self, + batch_inds: np.ndarray, + env: Optional[VecNormalize] = None, + ) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME return DictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 0cf4917..793cfc5 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -418,8 +418,7 @@ class ActorCriticPolicy(BasePolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - # TODO(antonin): update type annotation when we remove shared network support - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -452,21 +451,15 @@ class ActorCriticPolicy(BasePolicy): normalize_images=normalize_images, ) - # Convert [dict()] to dict() as shared network are deprecated - if isinstance(net_arch, list) and len(net_arch) > 0: - if isinstance(net_arch[0], dict): - warnings.warn( - ( - "As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " - "you should now pass directly a dictionary and not a list " - "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" - ), - ) - net_arch = net_arch[0] - else: - # Note: deprecation warning will be emitted - # by the MlpExtractor constructor - pass + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + warnings.warn( + ( + "As shared layers in the mlp_extractor are removed since SB3 v1.8.0, " + "you should now pass directly a dictionary and not a list " + "(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])" + ), + ) + net_arch = net_arch[0] # Default network architecture, from stable-baselines if net_arch is None: @@ -488,12 +481,6 @@ class ActorCriticPolicy(BasePolicy): else: self.pi_features_extractor = self.features_extractor self.vf_features_extractor = self.make_features_extractor() - # if the features extractor is not shared, there cannot be shared layers in the mlp_extractor - # TODO(antonin): update the check once we change net_arch behavior - if isinstance(net_arch, list) and len(net_arch) > 0: - raise ValueError( - "Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor" - ) self.log_std_init = log_std_init dist_kwargs = None @@ -770,7 +757,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): observation_space: spaces.Space, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, @@ -843,7 +830,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy): observation_space: spaces.Dict, action_space: spaces.Space, lr_schedule: Schedule, - net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, activation_fn: Type[nn.Module] = nn.Tanh, ortho_init: bool = True, use_sde: bool = False, diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 302d9b1..44714d6 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,5 +1,3 @@ -import warnings -from itertools import zip_longest from typing import Dict, List, Tuple, Type, Union import gym @@ -151,98 +149,57 @@ class MlpExtractor(nn.Module): Constructs an MLP that receives the output from a previous features extractor (i.e. a CNN) or directly the observations (if no features extractor is applied) as an input and outputs a latent representation for the policy and a value network. - The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many - of them are shared between the policy network and the value network. It is assumed to be a list with the following - structure: - 1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer. - If the number of ints is zero, there will be no shared layers. - 2. An optional dict, to specify the following non-shared layers for the value network and the policy network. - It is formatted like ``dict(vf=[], pi=[])``. - If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed. + The ``net_arch`` parameter allows to specify the amount and size of the hidden layers. + It can be in either of the following forms: + 1. ``dict(vf=[], pi=[])``: to specify the amount and size of the layers in the + policy and value nets individually. If it is missing any of the keys (pi or vf), + zero layers will be considered for that key. + 2. ``[]``: "shortcut" in case the amount and size of the layers + in the policy and value nets are the same. Same as ``dict(vf=int_list, pi=int_list)`` + where int_list is the same for the actor and critic. - Deprecation note: shared layers in ``net_arch`` are deprecated, please use separate - pi and vf networks (e.g. net_arch=dict(pi=[...], vf=[...])) - - For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value - network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec - would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128 - would be specified as [128, 128]. - - Adapted from Stable Baselines. + .. note:: + If a key is not specified or an empty list is passed ``[]``, a linear network will be used. :param feature_dim: Dimension of the feature vector (can be the output of a CNN) :param net_arch: The specification of the policy and value networks. See above for details on its formatting. :param activation_fn: The activation function to use for the networks. - :param device: + :param device: PyTorch device. """ def __init__( self, feature_dim: int, - net_arch: Union[Dict[str, List[int]], List[Union[int, Dict[str, List[int]]]]], + net_arch: Union[List[int], Dict[str, List[int]]], activation_fn: Type[nn.Module], device: Union[th.device, str] = "auto", ) -> None: super().__init__() device = get_device(device) - shared_net: List[nn.Module] = [] policy_net: List[nn.Module] = [] value_net: List[nn.Module] = [] - policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network - value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network - last_layer_dim_shared = feature_dim + last_layer_dim_pi = feature_dim + last_layer_dim_vf = feature_dim - if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int): - warnings.warn( - ( - "Shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, " - "please use separate pi and vf networks " - "(e.g. net_arch=dict(pi=[...], vf=[...]))" - ), - DeprecationWarning, - ) - - # TODO(antonin): update behavior for net_arch=[64, 64] - # once shared networks are removed + # save dimensions of layers in policy and value nets if isinstance(net_arch, dict): - policy_only_layers = net_arch["pi"] - value_only_layers = net_arch["vf"] + # Note: if key is not specificed, assume linear network + pi_layers_dims = net_arch.get("pi", []) # Layer sizes of the policy network + vf_layers_dims = net_arch.get("vf", []) # Layer sizes of the value network else: - # Iterate through the shared layers and build the shared parts of the network - for layer in net_arch: - if isinstance(layer, int): # Check that this is a shared layer - shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer - shared_net.append(activation_fn()) - last_layer_dim_shared = layer - else: - assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" - if "pi" in layer: - assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers." - policy_only_layers = layer["pi"] - - if "vf" in layer: - assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers." - value_only_layers = layer["vf"] - break # From here on the network splits up in policy and value network - - last_layer_dim_pi = last_layer_dim_shared - last_layer_dim_vf = last_layer_dim_shared - - # Build the non-shared part of the network - for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers): - if pi_layer_size is not None: - assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers." - policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size)) - policy_net.append(activation_fn()) - last_layer_dim_pi = pi_layer_size - - if vf_layer_size is not None: - assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers." - value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size)) - value_net.append(activation_fn()) - last_layer_dim_vf = vf_layer_size + pi_layers_dims = vf_layers_dims = net_arch + # Iterate through the policy layers and build the policy net + for curr_layer_dim in pi_layers_dims: + policy_net.append(nn.Linear(last_layer_dim_pi, curr_layer_dim)) + policy_net.append(activation_fn()) + last_layer_dim_pi = curr_layer_dim + # Iterate through the value layers and build the value net + for curr_layer_dim in vf_layers_dims: + value_net.append(nn.Linear(last_layer_dim_vf, curr_layer_dim)) + value_net.append(activation_fn()) + last_layer_dim_vf = curr_layer_dim # Save dim, used to create the distributions self.latent_dim_pi = last_layer_dim_pi @@ -250,7 +207,6 @@ class MlpExtractor(nn.Module): # Create networks # If the list of layers is empty, the network will just act as an Identity module - self.shared_net = nn.Sequential(*shared_net).to(device) self.policy_net = nn.Sequential(*policy_net).to(device) self.value_net = nn.Sequential(*value_net).to(device) @@ -259,14 +215,13 @@ class MlpExtractor(nn.Module): :return: latent_policy, latent_value of the specified network. If all layers are shared, then ``latent_policy == latent_value`` """ - shared_latent = self.shared_net(features) - return self.policy_net(shared_latent), self.value_net(shared_latent) + return self.forward_actor(features), self.forward_critic(features) def forward_actor(self, features: th.Tensor) -> th.Tensor: - return self.policy_net(self.shared_net(features)) + return self.policy_net(features) def forward_critic(self, features: th.Tensor) -> th.Tensor: - return self.value_net(self.shared_net(features)) + return self.value_net(features) class CombinedExtractor(BaseFeaturesExtractor): diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 0d03ef9..c3d22c0 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a1 +1.8.0a2 diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 85c3d37..1f89b23 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -9,21 +9,21 @@ from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike "net_arch", [ [], - dict(vf=[16], pi=[8]), - # [] behavior will change [4], [4, 4], - # All values below are deprecated - [12, dict(vf=[16], pi=[8])], - [12, dict(vf=[8, 4], pi=[8])], - [12, dict(vf=[8], pi=[8, 4])], - [12, dict(pi=[8])], + dict(vf=[16], pi=[8]), + dict(vf=[8, 4], pi=[8]), + dict(vf=[8], pi=[8, 4]), + dict(pi=[8]), + # Old format, emits a warning + [dict(vf=[8])], + [dict(vf=[8], pi=[4])], ], ) @pytest.mark.parametrize("model_class", [A2C, PPO]) def test_flexible_mlp(model_class, net_arch): - if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int): - with pytest.warns(DeprecationWarning): + if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict): + with pytest.warns(UserWarning): _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) else: _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300) @@ -62,10 +62,3 @@ def test_tf_like_rmsprop_optimizer(): def test_dqn_custom_policy(): policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32]) _ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300) - - -@pytest.mark.parametrize("model_class", [A2C, PPO]) -def test_not_shared_features_extractor(model_class): - policy_kwargs = dict(net_arch=[12, dict(vf=[16], pi=[8])], share_features_extractor=False) - with pytest.raises(ValueError): - model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs) diff --git a/tests/test_identity.py b/tests/test_identity.py index f5bbc49..cc7746b 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -45,6 +45,8 @@ def test_continuous(model_class): n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) kwargs["action_noise"] = action_noise + elif model_class in [A2C]: + kwargs["policy_kwargs"]["log_std_init"] = -0.5 model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)