mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-02 03:55:39 +00:00
Removed shared layers in mlp_extractor (#1292)
* Modified actor-critic policies & MlpExtractor class ActorCriticPolicy: - changed type hint of net_arch param: now it's a dict - removed check that if features extractor is not shared: no shared layers are allowed in the mlp_extractor regardless of the features extractor ActorCriticCnnPolicy: - changed type hint of net_arch param: now it's a dict MultiInputActorcriticPolicy: - changed type hint of net_arch param: now it's a dict MlpExtractor: - changed type hint of net_arch param: now it's a dict - adapted networks creation - adapted methods: forward, forward_actor & forward_critic * Removed shared layers in mlp_extractor * Updated docs and changelog + reformat * Updated custom policy tests * Removed test on deprecation warning for share layers in mlp_extractor Now shared layers are removed * Update version * Update RL Zoo doc * Fix linter warnings * Add ruff to Makefile (experimental) * Add backward compat code and minor updates * Update tests * Add backward compatibility * Fix test * Improve compat code Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
This commit is contained in:
parent
69fdf155e1
commit
b702884c23
11 changed files with 122 additions and 162 deletions
7
Makefile
7
Makefile
|
|
@ -19,6 +19,13 @@ lint:
|
|||
# exit-zero treats all errors as warnings.
|
||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||
|
||||
ruff:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
ruff ${LINT_PATHS} --select=E9,F63,F7,F82 --show-source
|
||||
# exit-zero treats all errors as warnings.
|
||||
ruff ${LINT_PATHS} --exit-zero --line-length 127
|
||||
|
||||
format:
|
||||
# Sort imports
|
||||
isort ${LINT_PATHS}
|
||||
|
|
|
|||
|
|
@ -117,11 +117,6 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
|
|||
``policy_kwargs`` (both for on-policy and off-policy algorithms).
|
||||
|
||||
|
||||
.. warning::
|
||||
If the features extractor is **non-shared**, it is **not** possible to have shared layers in the ``mlp_extractor``.
|
||||
Please note that this option is **deprecated**, therefore in a future release the layers in the ``mlp_extractor`` will have to be non-shared.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch as th
|
||||
|
|
@ -242,41 +237,31 @@ On-Policy Algorithms
|
|||
Custom Networks
|
||||
---------------
|
||||
|
||||
.. warning::
|
||||
Shared layers in the the ``mlp_extractor`` are **deprecated**.
|
||||
In a future release all layers will have to be non-shared.
|
||||
If needed, you can implement a custom policy network (see `advanced example below <#advanced-example>`_).
|
||||
|
||||
.. warning::
|
||||
In the next Stable-Baselines3 release, the behavior of ``net_arch=[128, 128]`` will change
|
||||
to match the one of off-policy algorithms: it will create **separate** networks (instead of shared currently)
|
||||
for the actor and the critic, with the same architecture.
|
||||
|
||||
|
||||
If you need a network architecture that is different for the actor and the critic when using ``PPO``, ``A2C`` or ``TRPO``,
|
||||
you can pass a dictionary of the following structure: ``dict(pi=[<actor network architecture>], vf=[<critic network architecture>])``.
|
||||
|
||||
For example, if you want a different architecture for the actor (aka ``pi``) and the critic ( value-function aka ``vf``) networks,
|
||||
then you can specify ``net_arch=dict(pi=[32, 32], vf=[64, 64])``.
|
||||
|
||||
.. Otherwise, to have actor and critic that share the same network architecture,
|
||||
.. you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each).
|
||||
Otherwise, to have actor and critic that share the same network architecture,
|
||||
you only need to specify ``net_arch=[128, 128]`` (here, two hidden layers of 128 units each, this is equivalent to ``net_arch=dict(pi=[128, 128], vf=[128, 128])``).
|
||||
|
||||
If shared layers are needed, you need to implement a custom policy network (see `advanced example below <#advanced-example>`_).
|
||||
|
||||
Examples
|
||||
~~~~~~~~
|
||||
|
||||
.. TODO(antonin): uncomment when shared network is removed
|
||||
.. Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
|
||||
..
|
||||
.. .. code-block:: none
|
||||
..
|
||||
.. obs
|
||||
.. / \
|
||||
.. <128> <128>
|
||||
.. | |
|
||||
.. <128> <128>
|
||||
.. | |
|
||||
.. action value
|
||||
Same architecture for actor and critic with two layers of size 128: ``net_arch=[128, 128]``
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
obs
|
||||
/ \
|
||||
<128> <128>
|
||||
| |
|
||||
<128> <128>
|
||||
| |
|
||||
action value
|
||||
|
||||
Different architectures for actor and critic: ``net_arch=dict(pi=[32, 32], vf=[64, 64])``
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ Goals of this repository:
|
|||
Installation
|
||||
------------
|
||||
|
||||
Option 1: install the python package ``pip install rl_zoo3``
|
||||
|
||||
or:
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
::
|
||||
|
|
@ -42,7 +46,10 @@ Installation
|
|||
::
|
||||
|
||||
apt-get install swig cmake ffmpeg
|
||||
# full dependencies
|
||||
pip install -r requirements.txt
|
||||
# minimal dependencies
|
||||
pip install -e .
|
||||
|
||||
|
||||
Train an Agent
|
||||
|
|
@ -56,13 +63,13 @@ using:
|
|||
|
||||
::
|
||||
|
||||
python train.py --algo algo_name --env env_id
|
||||
python -m rl_zoo3.train --algo algo_name --env env_id
|
||||
|
||||
For example (with evaluation and checkpoints):
|
||||
|
||||
::
|
||||
|
||||
python train.py --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
|
||||
python -m rl_zoo3.train --algo ppo --env CartPole-v1 --eval-freq 10000 --save-freq 50000
|
||||
|
||||
|
||||
Continue training (here, load pretrained agent for Breakout and continue
|
||||
|
|
@ -70,7 +77,7 @@ training for 5000 steps):
|
|||
|
||||
::
|
||||
|
||||
python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
|
||||
python -m rl_zoo3.train --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4_1/BreakoutNoFrameskip-v4.zip -n 5000
|
||||
|
||||
|
||||
Enjoy a Trained Agent
|
||||
|
|
@ -80,13 +87,13 @@ If the trained agent exists, then you can see it in action using:
|
|||
|
||||
::
|
||||
|
||||
python enjoy.py --algo algo_name --env env_id
|
||||
python -m rl_zoo3.enjoy --algo algo_name --env env_id
|
||||
|
||||
For example, enjoy A2C on Breakout during 5000 timesteps:
|
||||
|
||||
::
|
||||
|
||||
python enjoy.py --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
|
||||
python -m rl_zoo3.enjoy --algo a2c --env BreakoutNoFrameskip-v4 --folder rl-trained-agents/ -n 5000
|
||||
|
||||
|
||||
Hyperparameter Optimization
|
||||
|
|
@ -100,7 +107,7 @@ with a budget of 1000 trials and a maximum of 50000 steps:
|
|||
|
||||
::
|
||||
|
||||
python train.py --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
|
||||
python -m rl_zoo3.train --algo ppo --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
|
||||
--sampler random --pruner median
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ Changelog
|
|||
==========
|
||||
|
||||
|
||||
Release 1.8.0a1 (WIP)
|
||||
Release 1.8.0a2 (WIP)
|
||||
--------------------------
|
||||
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Removed shared layers in ``mlp_extractor`` (@AlexPasqua)
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@ -667,6 +667,11 @@ class BaseAlgorithm(ABC):
|
|||
if "policy_kwargs" in data:
|
||||
if "device" in data["policy_kwargs"]:
|
||||
del data["policy_kwargs"]["device"]
|
||||
# backward compatibility, convert to new format
|
||||
if "net_arch" in data["policy_kwargs"] and len(data["policy_kwargs"]["net_arch"]) > 0:
|
||||
saved_net_arch = data["policy_kwargs"]["net_arch"]
|
||||
if isinstance(saved_net_arch, list) and isinstance(saved_net_arch[0], dict):
|
||||
data["policy_kwargs"]["net_arch"] = saved_net_arch[0]
|
||||
|
||||
if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
|
||||
raise ValueError(
|
||||
|
|
@ -726,7 +731,6 @@ class BaseAlgorithm(ABC):
|
|||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
# put other pytorch variables back in place
|
||||
if pytorch_variables is not None:
|
||||
for name in pytorch_variables:
|
||||
|
|
|
|||
|
|
@ -474,7 +474,11 @@ class RolloutBuffer(BaseBuffer):
|
|||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
def _get_samples(
|
||||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
data = (
|
||||
self.observations[batch_inds],
|
||||
self.actions[batch_inds],
|
||||
|
|
@ -603,7 +607,11 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
self.full = True
|
||||
self.pos = 0
|
||||
|
||||
def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
def sample(
|
||||
self,
|
||||
batch_size: int,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
"""
|
||||
Sample elements from the replay buffer.
|
||||
|
||||
|
|
@ -614,7 +622,11 @@ class DictReplayBuffer(ReplayBuffer):
|
|||
"""
|
||||
return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env)
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
def _get_samples(
|
||||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictReplayBufferSamples: # type: ignore[signature-mismatch] #FIXME:
|
||||
# Sample randomly the env idx
|
||||
env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
|
||||
|
||||
|
|
@ -743,7 +755,10 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
if self.pos == self.buffer_size:
|
||||
self.full = True
|
||||
|
||||
def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
|
||||
def get(
|
||||
self,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> Generator[DictRolloutBufferSamples, None, None]: # type: ignore[signature-mismatch] #FIXME
|
||||
assert self.full, ""
|
||||
indices = np.random.permutation(self.buffer_size * self.n_envs)
|
||||
# Prepare the data
|
||||
|
|
@ -767,7 +782,11 @@ class DictRolloutBuffer(RolloutBuffer):
|
|||
yield self._get_samples(indices[start_idx : start_idx + batch_size])
|
||||
start_idx += batch_size
|
||||
|
||||
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
def _get_samples(
|
||||
self,
|
||||
batch_inds: np.ndarray,
|
||||
env: Optional[VecNormalize] = None,
|
||||
) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME
|
||||
|
||||
return DictRolloutBufferSamples(
|
||||
observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()},
|
||||
|
|
|
|||
|
|
@ -418,8 +418,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
# TODO(antonin): update type annotation when we remove shared network support
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
@ -452,21 +451,15 @@ class ActorCriticPolicy(BasePolicy):
|
|||
normalize_images=normalize_images,
|
||||
)
|
||||
|
||||
# Convert [dict()] to dict() as shared network are deprecated
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||
if isinstance(net_arch[0], dict):
|
||||
warnings.warn(
|
||||
(
|
||||
"As shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
|
||||
"you should now pass directly a dictionary and not a list "
|
||||
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
||||
),
|
||||
)
|
||||
net_arch = net_arch[0]
|
||||
else:
|
||||
# Note: deprecation warning will be emitted
|
||||
# by the MlpExtractor constructor
|
||||
pass
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
|
||||
warnings.warn(
|
||||
(
|
||||
"As shared layers in the mlp_extractor are removed since SB3 v1.8.0, "
|
||||
"you should now pass directly a dictionary and not a list "
|
||||
"(net_arch=dict(pi=..., vf=...) instead of net_arch=[dict(pi=..., vf=...)])"
|
||||
),
|
||||
)
|
||||
net_arch = net_arch[0]
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
if net_arch is None:
|
||||
|
|
@ -488,12 +481,6 @@ class ActorCriticPolicy(BasePolicy):
|
|||
else:
|
||||
self.pi_features_extractor = self.features_extractor
|
||||
self.vf_features_extractor = self.make_features_extractor()
|
||||
# if the features extractor is not shared, there cannot be shared layers in the mlp_extractor
|
||||
# TODO(antonin): update the check once we change net_arch behavior
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0:
|
||||
raise ValueError(
|
||||
"Error: if the features extractor is not shared, there cannot be shared layers in the mlp_extractor"
|
||||
)
|
||||
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
|
|
@ -770,7 +757,7 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
|
|||
observation_space: spaces.Space,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
@ -843,7 +830,7 @@ class MultiInputActorCriticPolicy(ActorCriticPolicy):
|
|||
observation_space: spaces.Dict,
|
||||
action_space: spaces.Space,
|
||||
lr_schedule: Schedule,
|
||||
net_arch: Union[List[int], Dict[str, List[int]], List[Dict[str, List[int]]], None] = None,
|
||||
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import warnings
|
||||
from itertools import zip_longest
|
||||
from typing import Dict, List, Tuple, Type, Union
|
||||
|
||||
import gym
|
||||
|
|
@ -151,98 +149,57 @@ class MlpExtractor(nn.Module):
|
|||
Constructs an MLP that receives the output from a previous features extractor (i.e. a CNN) or directly
|
||||
the observations (if no features extractor is applied) as an input and outputs a latent representation
|
||||
for the policy and a value network.
|
||||
The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
|
||||
of them are shared between the policy network and the value network. It is assumed to be a list with the following
|
||||
structure:
|
||||
|
||||
1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
|
||||
If the number of ints is zero, there will be no shared layers.
|
||||
2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
|
||||
It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
|
||||
If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed.
|
||||
The ``net_arch`` parameter allows to specify the amount and size of the hidden layers.
|
||||
It can be in either of the following forms:
|
||||
1. ``dict(vf=[<list of layer sizes>], pi=[<list of layer sizes>])``: to specify the amount and size of the layers in the
|
||||
policy and value nets individually. If it is missing any of the keys (pi or vf),
|
||||
zero layers will be considered for that key.
|
||||
2. ``[<list of layer sizes>]``: "shortcut" in case the amount and size of the layers
|
||||
in the policy and value nets are the same. Same as ``dict(vf=int_list, pi=int_list)``
|
||||
where int_list is the same for the actor and critic.
|
||||
|
||||
Deprecation note: shared layers in ``net_arch`` are deprecated, please use separate
|
||||
pi and vf networks (e.g. net_arch=dict(pi=[...], vf=[...]))
|
||||
|
||||
For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value
|
||||
network of size 255 and a single non-shared layer of size 128 for the policy network, the following layers_spec
|
||||
would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
|
||||
would be specified as [128, 128].
|
||||
|
||||
Adapted from Stable Baselines.
|
||||
.. note::
|
||||
If a key is not specified or an empty list is passed ``[]``, a linear network will be used.
|
||||
|
||||
:param feature_dim: Dimension of the feature vector (can be the output of a CNN)
|
||||
:param net_arch: The specification of the policy and value networks.
|
||||
See above for details on its formatting.
|
||||
:param activation_fn: The activation function to use for the networks.
|
||||
:param device:
|
||||
:param device: PyTorch device.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim: int,
|
||||
net_arch: Union[Dict[str, List[int]], List[Union[int, Dict[str, List[int]]]]],
|
||||
net_arch: Union[List[int], Dict[str, List[int]]],
|
||||
activation_fn: Type[nn.Module],
|
||||
device: Union[th.device, str] = "auto",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
device = get_device(device)
|
||||
shared_net: List[nn.Module] = []
|
||||
policy_net: List[nn.Module] = []
|
||||
value_net: List[nn.Module] = []
|
||||
policy_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the policy network
|
||||
value_only_layers: List[int] = [] # Layer sizes of the network that only belongs to the value network
|
||||
last_layer_dim_shared = feature_dim
|
||||
last_layer_dim_pi = feature_dim
|
||||
last_layer_dim_vf = feature_dim
|
||||
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int):
|
||||
warnings.warn(
|
||||
(
|
||||
"Shared layers in the mlp_extractor are deprecated and will be removed in SB3 v1.8.0, "
|
||||
"please use separate pi and vf networks "
|
||||
"(e.g. net_arch=dict(pi=[...], vf=[...]))"
|
||||
),
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
# TODO(antonin): update behavior for net_arch=[64, 64]
|
||||
# once shared networks are removed
|
||||
# save dimensions of layers in policy and value nets
|
||||
if isinstance(net_arch, dict):
|
||||
policy_only_layers = net_arch["pi"]
|
||||
value_only_layers = net_arch["vf"]
|
||||
# Note: if key is not specificed, assume linear network
|
||||
pi_layers_dims = net_arch.get("pi", []) # Layer sizes of the policy network
|
||||
vf_layers_dims = net_arch.get("vf", []) # Layer sizes of the value network
|
||||
else:
|
||||
# Iterate through the shared layers and build the shared parts of the network
|
||||
for layer in net_arch:
|
||||
if isinstance(layer, int): # Check that this is a shared layer
|
||||
shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer
|
||||
shared_net.append(activation_fn())
|
||||
last_layer_dim_shared = layer
|
||||
else:
|
||||
assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
|
||||
if "pi" in layer:
|
||||
assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
|
||||
policy_only_layers = layer["pi"]
|
||||
|
||||
if "vf" in layer:
|
||||
assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
|
||||
value_only_layers = layer["vf"]
|
||||
break # From here on the network splits up in policy and value network
|
||||
|
||||
last_layer_dim_pi = last_layer_dim_shared
|
||||
last_layer_dim_vf = last_layer_dim_shared
|
||||
|
||||
# Build the non-shared part of the network
|
||||
for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
|
||||
if pi_layer_size is not None:
|
||||
assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
|
||||
policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
|
||||
policy_net.append(activation_fn())
|
||||
last_layer_dim_pi = pi_layer_size
|
||||
|
||||
if vf_layer_size is not None:
|
||||
assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
|
||||
value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size))
|
||||
value_net.append(activation_fn())
|
||||
last_layer_dim_vf = vf_layer_size
|
||||
pi_layers_dims = vf_layers_dims = net_arch
|
||||
# Iterate through the policy layers and build the policy net
|
||||
for curr_layer_dim in pi_layers_dims:
|
||||
policy_net.append(nn.Linear(last_layer_dim_pi, curr_layer_dim))
|
||||
policy_net.append(activation_fn())
|
||||
last_layer_dim_pi = curr_layer_dim
|
||||
# Iterate through the value layers and build the value net
|
||||
for curr_layer_dim in vf_layers_dims:
|
||||
value_net.append(nn.Linear(last_layer_dim_vf, curr_layer_dim))
|
||||
value_net.append(activation_fn())
|
||||
last_layer_dim_vf = curr_layer_dim
|
||||
|
||||
# Save dim, used to create the distributions
|
||||
self.latent_dim_pi = last_layer_dim_pi
|
||||
|
|
@ -250,7 +207,6 @@ class MlpExtractor(nn.Module):
|
|||
|
||||
# Create networks
|
||||
# If the list of layers is empty, the network will just act as an Identity module
|
||||
self.shared_net = nn.Sequential(*shared_net).to(device)
|
||||
self.policy_net = nn.Sequential(*policy_net).to(device)
|
||||
self.value_net = nn.Sequential(*value_net).to(device)
|
||||
|
||||
|
|
@ -259,14 +215,13 @@ class MlpExtractor(nn.Module):
|
|||
:return: latent_policy, latent_value of the specified network.
|
||||
If all layers are shared, then ``latent_policy == latent_value``
|
||||
"""
|
||||
shared_latent = self.shared_net(features)
|
||||
return self.policy_net(shared_latent), self.value_net(shared_latent)
|
||||
return self.forward_actor(features), self.forward_critic(features)
|
||||
|
||||
def forward_actor(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.policy_net(self.shared_net(features))
|
||||
return self.policy_net(features)
|
||||
|
||||
def forward_critic(self, features: th.Tensor) -> th.Tensor:
|
||||
return self.value_net(self.shared_net(features))
|
||||
return self.value_net(features)
|
||||
|
||||
|
||||
class CombinedExtractor(BaseFeaturesExtractor):
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
1.8.0a1
|
||||
1.8.0a2
|
||||
|
|
|
|||
|
|
@ -9,21 +9,21 @@ from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
|
|||
"net_arch",
|
||||
[
|
||||
[],
|
||||
dict(vf=[16], pi=[8]),
|
||||
# [<layer_sizes>] behavior will change
|
||||
[4],
|
||||
[4, 4],
|
||||
# All values below are deprecated
|
||||
[12, dict(vf=[16], pi=[8])],
|
||||
[12, dict(vf=[8, 4], pi=[8])],
|
||||
[12, dict(vf=[8], pi=[8, 4])],
|
||||
[12, dict(pi=[8])],
|
||||
dict(vf=[16], pi=[8]),
|
||||
dict(vf=[8, 4], pi=[8]),
|
||||
dict(vf=[8], pi=[8, 4]),
|
||||
dict(pi=[8]),
|
||||
# Old format, emits a warning
|
||||
[dict(vf=[8])],
|
||||
[dict(vf=[8], pi=[4])],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_flexible_mlp(model_class, net_arch):
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], int):
|
||||
with pytest.warns(DeprecationWarning):
|
||||
if isinstance(net_arch, list) and len(net_arch) > 0 and isinstance(net_arch[0], dict):
|
||||
with pytest.warns(UserWarning):
|
||||
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300)
|
||||
else:
|
||||
_ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=64).learn(300)
|
||||
|
|
@ -62,10 +62,3 @@ def test_tf_like_rmsprop_optimizer():
|
|||
def test_dqn_custom_policy():
|
||||
policy_kwargs = dict(optimizer_class=RMSpropTFLike, net_arch=[32])
|
||||
_ = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, learning_starts=100).learn(300)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
def test_not_shared_features_extractor(model_class):
|
||||
policy_kwargs = dict(net_arch=[12, dict(vf=[16], pi=[8])], share_features_extractor=False)
|
||||
with pytest.raises(ValueError):
|
||||
model_class("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)
|
||||
|
|
|
|||
|
|
@ -45,6 +45,8 @@ def test_continuous(model_class):
|
|||
n_actions = 1
|
||||
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
|
||||
kwargs["action_noise"] = action_noise
|
||||
elif model_class in [A2C]:
|
||||
kwargs["policy_kwargs"]["log_std_init"] = -0.5
|
||||
|
||||
model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue