mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
* Replace "nature" with "Nature" (magazine) to reduce confusion * Replace "nature" with "Nature" (magazine) to reduce confusion * Update changelog Co-authored-by: mel <callmesolis@gmail.com>
317 lines
14 KiB
Python
317 lines
14 KiB
Python
from itertools import zip_longest
|
|
from typing import Dict, List, Tuple, Type, Union
|
|
|
|
import gym
|
|
import torch as th
|
|
from torch import nn
|
|
|
|
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
|
|
from stable_baselines3.common.type_aliases import TensorDict
|
|
from stable_baselines3.common.utils import get_device
|
|
|
|
|
|
class BaseFeaturesExtractor(nn.Module):
    """
    Common parent class for features extractor modules.

    :param observation_space: Observation space, stored on the instance.
    :param features_dim: Number of features extracted. Must be strictly positive.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        # The default value (0) does not pass this check,
        # so callers have to supply a positive size.
        assert features_dim > 0
        self._features_dim = features_dim
        self._observation_space = observation_space

    @property
    def features_dim(self) -> int:
        """Number of features stored for this extractor."""
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Not implemented here: always raises ``NotImplementedError``."""
        raise NotImplementedError()
|
|
|
|
|
|
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Features extractor that only flattens its input.
    Serves as a placeholder when no feature extraction is needed.

    :param observation_space: Observation space; its flattened dimension
        (as given by ``get_flattened_obs_dim``) is used as ``features_dim``.
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return ``observations`` after applying ``nn.Flatten``."""
        flattened = self.flatten(observations)
        return flattened
|
|
|
|
|
|
class NatureCNN(BaseFeaturesExtractor):
    """
    Convolutional network from the DQN Nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    :param observation_space: Image observation space; the first axis of its
        shape is used as the number of input channels.
    :param features_dim: Number of features extracted.
        This is the number of units of the final linear layer.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        # Images are expected as CxHxW (channels first);
        # re-ordering is left to the pre-processing or to a wrapper.
        assert is_image_space(observation_space, check_channels=False), (
            "You should use NatureCNN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using a custom environment,\n"
            "please check it using our env checker:\n"
            "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
        )
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # The flattened size depends on the image shape, so it is measured
        # by pushing one sampled observation (with a batch axis) through the conv stack.
        with th.no_grad():
            sample_batch = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Apply the convolutional stack, then the final linear layer + ReLU."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)
|
|
|
|
|
|
def create_mlp(
    input_dim: int,
    output_dim: int,
    net_arch: List[int],
    activation_fn: Type[nn.Module] = nn.ReLU,
    squash_output: bool = False,
    with_bias: bool = True,
) -> List[nn.Module]:
    """
    Create a multi layer perceptron (MLP), which is
    a collection of fully-connected layers each followed by an activation function.

    :param input_dim: Dimension of the input vector
    :param output_dim: Dimension of the output layer. When it is not strictly
        positive, no output layer is appended and the list ends with the last
        hidden layer's activation.
    :param net_arch: Architecture of the neural net
        It represents the number of units per layer.
        The length of this list is the number of layers.
    :param activation_fn: The activation function
        to use after each layer.
    :param squash_output: Whether to squash the output using a Tanh
        activation function
    :param with_bias: If set to False, the linear layers will not learn an additive bias
    :return: The list of modules making up the MLP, in forward order
        (not yet wrapped in a ``nn.Sequential``).
    """
    modules: List[nn.Module] = []

    # Track the size feeding the next layer instead of indexing into net_arch,
    # so the first hidden layer needs no special case.
    last_layer_dim = input_dim
    for layer_dim in net_arch:
        modules.append(nn.Linear(last_layer_dim, layer_dim, bias=with_bias))
        modules.append(activation_fn())
        last_layer_dim = layer_dim

    if output_dim > 0:
        # The output layer is not followed by ``activation_fn``
        modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
    if squash_output:
        modules.append(nn.Tanh())
    return modules
|
|
|
|
|
|
class MlpExtractor(nn.Module):
    """
    Builds the MLP that maps a feature vector (the output of a previous features
    extractor such as a CNN, or the observations themselves when no extractor is
    applied) to two latent representations: one for the policy and one for the
    value network.

    ``net_arch`` controls how many hidden layers there are, how large they are and
    how many of them are shared by the policy and the value network. It is read
    front to back as:

    1. Any number (zero allowed) of integers, each adding one shared layer with
       that many units. With no integers there are no shared layers.
    2. An optional dict describing the non-shared layers, formatted like
       ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       A missing key (``pi`` or ``vf``) means no non-shared layers for that network.
       Reading stops at the first dict: entries placed after it are not used.

    For instance, one shared layer of size 55, followed by two non-shared value
    layers of size 255 and one non-shared policy layer of size 128, is written
    ``[55, dict(vf=[255, 255], pi=[128])]``. A fully shared network with two layers
    of size 128 is written ``[128, 128]``.

    Adapted from Stable Baselines.

    :param feature_dim: Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: The activation function to use for the networks.
    :param device: Passed to ``get_device``; the three sub-networks are moved
        to the device it returns.
    """

    def __init__(
        self,
        feature_dim: int,
        net_arch: List[Union[int, Dict[str, List[int]]]],
        activation_fn: Type[nn.Module],
        device: Union[th.device, str] = "auto",
    ):
        super().__init__()
        device = get_device(device)
        shared_modules, pi_modules, vf_modules = [], [], []
        pi_sizes = []  # Sizes of the layers owned by the policy network only
        vf_sizes = []  # Sizes of the layers owned by the value network only
        shared_out_dim = feature_dim

        # Walk through the leading ints (shared layers) until the dict, if any, is reached
        for spec in net_arch:
            if not isinstance(spec, int):
                assert isinstance(spec, dict), "Error: the net_arch list can only contain ints and dicts"
                if "pi" in spec:
                    assert isinstance(spec["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    pi_sizes = spec["pi"]

                if "vf" in spec:
                    assert isinstance(spec["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    vf_sizes = spec["vf"]
                # The network splits into policy and value parts from here on
                break

            # TODO: give layer a meaningful name
            shared_modules += [nn.Linear(shared_out_dim, spec), activation_fn()]
            shared_out_dim = spec

        pi_out_dim = vf_out_dim = shared_out_dim

        # Non-shared part: policy and value layers are created alternately,
        # the shorter list simply runs out first (zip_longest pads with None).
        for pi_size, vf_size in zip_longest(pi_sizes, vf_sizes):
            if pi_size is not None:
                assert isinstance(pi_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                pi_modules += [nn.Linear(pi_out_dim, pi_size), activation_fn()]
                pi_out_dim = pi_size

            if vf_size is not None:
                assert isinstance(vf_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                vf_modules += [nn.Linear(vf_out_dim, vf_size), activation_fn()]
                vf_out_dim = vf_size

        # Latent sizes are kept on the instance (used to create the distributions)
        self.latent_dim_pi = pi_out_dim
        self.latent_dim_vf = vf_out_dim

        # An empty module list yields an nn.Sequential that behaves as the identity
        self.shared_net = nn.Sequential(*shared_modules).to(device)
        self.policy_net = nn.Sequential(*pi_modules).to(device)
        self.value_net = nn.Sequential(*vf_modules).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        trunk_out = self.shared_net(features)
        latent_pi = self.policy_net(trunk_out)
        latent_vf = self.value_net(trunk_out)
        return latent_pi, latent_vf

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Return only the policy latent: shared layers followed by the policy layers."""
        trunk_out = self.shared_net(features)
        return self.policy_net(trunk_out)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Return only the value latent: shared layers followed by the value layers."""
        trunk_out = self.shared_net(features)
        return self.value_net(trunk_out)
|
|
|
|
|
|
class CombinedExtractor(BaseFeaturesExtractor):
    """
    Features extractor for Dict observation spaces.
    One sub-extractor is created per key of the space: a ``NatureCNN`` when the
    sub-space is an image space (according to ``is_image_space``), a plain
    ``nn.Flatten`` otherwise. The per-key outputs are concatenated along dim 1.

    :param observation_space: Dict observation space to build the extractors for.
    :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to
        256 to avoid exploding network sizes.
    """

    def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256):
        # TODO: the real features dim is only known once every sub-space has been visited,
        # so a placeholder value is handed to the parent here. This is dirty!
        super().__init__(observation_space, features_dim=1)

        per_key_extractors = {}
        concat_dim = 0
        for name, space in observation_space.spaces.items():
            if not is_image_space(space):
                # Vector-like entry: flatten it if needed
                per_key_extractors[name] = nn.Flatten()
                concat_dim += get_flattened_obs_dim(space)
                continue
            per_key_extractors[name] = NatureCNN(space, features_dim=cnn_output_dim)
            concat_dim += cnn_output_dim

        self.extractors = nn.ModuleDict(per_key_extractors)

        # Overwrite the placeholder features dim with the real concatenated size
        self._features_dim = concat_dim

    def forward(self, observations: TensorDict) -> th.Tensor:
        """Run each key through its extractor and concatenate the results along dim 1."""
        encoded = [extractor(observations[name]) for name, extractor in self.extractors.items()]
        return th.cat(encoded, dim=1)
|
|
|
|
|
|
def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]:
    """
    Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG).

    The ``net_arch`` parameter allows to specify the amount and size of the hidden layers,
    which can be different for the actor and the critic.
    It is assumed to be a list of ints or a dict.

    1. If it is a list, actor and critic networks will have the same architecture.
        The architecture is represented by a list of integers (of arbitrary length (zero allowed))
        each specifying the number of units per layer.
        If the number of ints is zero, the network will be linear.
    2. If it is a dict, it should have the following structure:
        ``dict(qf=[<critic network architecture>], pi=[<actor network architecture>])``.
        where the network architecture is a list as described in 1.

    For example, to have actor and critic that share the same network architecture,
    you only need to specify ``net_arch=[256, 256]`` (here, two hidden layers of 256 units each).

    If you want a different architecture for the actor and the critic,
    then you can specify ``net_arch=dict(qf=[400, 300], pi=[64, 64])``.

    .. note::
        Compared to their on-policy counterparts, no shared layers (other than the features extractor)
        between the actor and the critic are allowed (to prevent issues with target networks).

    :param net_arch: The specification of the actor and critic networks.
        See above for details on its formatting.
    :return: The network architectures for the actor and the critic, in that order.
        For a list input, both entries are the very same list object that was passed in.
    :raises AssertionError: If ``net_arch`` is neither a list nor a dict,
        or if it is a dict missing the ``pi`` or ``qf`` key.
    """
    if isinstance(net_arch, list):
        # Same architecture for both networks
        return net_arch, net_arch

    assert isinstance(net_arch, dict), "Error: the net_arch can only be a list of ints or a dict"
    assert "pi" in net_arch, "Error: no key 'pi' was provided in net_arch for the actor network"
    assert "qf" in net_arch, "Error: no key 'qf' was provided in net_arch for the critic network"
    return net_arch["pi"], net_arch["qf"]
|