stable-baselines3/stable_baselines3/common/torch_layers.py
Anssi 44f8218df0
Review of code (A2C, PPO and refactoring) (#35)
* Split torch module code into torch_layers file

* Updated reference to CNN

* Change 'CxWxH' to 'CxHxW', as per common notion

* Fix missing import in policies.py

* Move PPOPolicy to OnlineActorCriticPolicy

* Create OnPolicyRLModel from PPO, and make A2C and PPO inherit

* Update A2C optimizer comment

* Clean weight init scales for clarity

* Fix A2C log_interval default parameter

* Rename 'progress' to 'progress_remaining'

* Rename 'Models' to 'Algorithms'

* Rename 'OnlineActorCriticPolicy' to 'ActorCriticPolicy'

* Move static functions out from BaseAlgorithm

* Move on/off_policy base algorithms to their own files

* Add files for A2C/PPO

* Fix docs

* Fix pytype

* Update documentation on OnPolicyAlgorithm

* Add proper docstring for on_policy rollout gathering

* Add bit clarification on the mlppolicy/cnnpolicy naming

* Move static function is_vectorized_policies to utils.py

* Checking docstrings, pep8 fixes

* Update changelog

* Clean changelog

* Remove policy warnings for sac/td3

* Add monitor_wrapper for OnPolicyAlgorithm. Clean tb logging variables. Add parameter keywords to OffPolicyAlgorithm super init

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
2020-06-09 13:54:18 +02:00

218 lines
9.7 KiB
Python

from typing import Union, Type, Dict, List, Tuple
from itertools import zip_longest
import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.utils import get_device
class BaseFeaturesExtractor(nn.Module):
    """
    Abstract features extractor: a module turning observations into features.

    Concrete extractors override ``forward``; the number of features they
    produce is exposed through the read-only ``features_dim`` property.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        Must be strictly positive: the default of 0 is rejected by the check
        in ``__init__``, so sub-classes always have to pass a real value.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        # Only a strictly positive number of features is meaningful.
        assert features_dim > 0
        self._observation_space = observation_space
        self._features_dim = features_dim

    @property
    def features_dim(self) -> int:
        """
        :return: (int) Number of features extracted by this module.
        """
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # The actual feature computation is left to sub-classes.
        raise NotImplementedError()
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Features extractor that only flattens its input with ``nn.Flatten``.
    Serves as a placeholder when no real feature extraction is needed.

    :param observation_space: (gym.Space) Its flattened dimension, as given by
        ``get_flattened_obs_dim``, becomes ``features_dim``.
    """

    def __init__(self, observation_space: gym.Space):
        n_features = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, n_features)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        flattened = self.flatten(observations)
        return flattened
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    Architecture: three convolutions (32 filters 8x8 stride 4, 64 filters 4x4
    stride 2, 64 filters 3x3 stride 1), each followed by a ReLU, then a
    fully-connected layer of ``features_dim`` units followed by a ReLU.

    :param observation_space: (gym.spaces.Box) Image observation space,
        expected channels first (CxHxW).
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box,
                 features_dim: int = 512):
        super(NatureCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space), ('You should use NatureCNN '
                                                   f'only with images not with {observation_space} '
                                                   '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
        n_input_channels = observation_space.shape[0]
        # The third convolution has 64 filters, as in the Nature paper.
        # NOTE: an earlier version of this class used 32 filters there; weights
        # saved with that version have different shapes and cannot be loaded.
        self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
                                 nn.ReLU(),
                                 nn.Flatten())

        # Infer the flattened CNN output size by running one sampled
        # observation (with an added batch dimension) through the CNN.
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
def create_mlp(input_dim: int,
               output_dim: int,
               net_arch: List[int],
               activation_fn: Type[nn.Module] = nn.ReLU,
               squash_output: bool = False,
               with_bias: bool = True) -> List[nn.Module]:
    """
    Create a multi layer perceptron (MLP), which is
    a collection of fully-connected layers each followed by an activation function.

    :param input_dim: (int) Dimension of the input vector
    :param output_dim: (int) Dimension of the output vector, i.e. the size of a
        final fully-connected layer that is NOT followed by ``activation_fn``.
        If ``output_dim <= 0``, no output layer is added and the module list
        ends with the activation of the last hidden layer (if any).
    :param net_arch: (List[int]) Architecture of the neural net
        It represents the number of units per layer.
        The length of this list is the number of layers.
    :param activation_fn: (Type[nn.Module]) The activation function
        to use after each layer.
    :param squash_output: (bool) Whether to squash the output using a Tanh
        activation function. The Tanh is appended even when no output layer
        is created (``output_dim <= 0``).
    :param with_bias: (bool) Whether the fully-connected layers have a bias term
        (default True).
    :return: (List[nn.Module])
    """
    modules: List[nn.Module] = []
    # Size of the tensor produced by the layers built so far.
    last_layer_dim = input_dim
    # Layers are created in the same order as before (hidden layers, then the
    # output layer), so weight initialisation under a fixed seed is unchanged.
    for layer_dim in net_arch:
        modules.append(nn.Linear(last_layer_dim, layer_dim, bias=with_bias))
        modules.append(activation_fn())
        last_layer_dim = layer_dim

    if output_dim > 0:
        modules.append(nn.Linear(last_layer_dim, output_dim, bias=with_bias))
    if squash_output:
        modules.append(nn.Tanh())
    return modules
class MlpExtractor(nn.Module):
    """
    Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and
    a value network. The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
    of them are shared between the policy network and the value network. It is assumed to be a list with the following
    structure:

    1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
       If the number of ints is zero, there will be no shared layers.
    2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
       It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       If one of the keys (pi or vf) is missing, that network gets no non-shared layers (empty list).
       The first dict ends the specification: any element that follows it in ``net_arch`` is ignored.

    For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value
    network of size 255 and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
    would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
    would be specified as ``[128, 128]``.

    Adapted from Stable Baselines.

    :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: (Type[nn.Module]) The activation function to use for the networks.
    :param device: (th.device) Device (resolved with ``get_device``) on which the networks are placed.
    """

    def __init__(self, feature_dim: int,
                 net_arch: List[Union[int, Dict[str, List[int]]]],
                 activation_fn: Type[nn.Module],
                 device: Union[th.device, str] = 'auto'):
        super(MlpExtractor, self).__init__()
        device = get_device(device)
        shared_net, policy_net, value_net = [], [], []
        policy_only_layers = []  # Layer sizes of the network that only belongs to the policy network
        value_only_layers = []  # Layer sizes of the network that only belongs to the value network
        last_layer_dim_shared = feature_dim

        # Iterate through the shared layers and build the shared parts of the network
        for layer in net_arch:
            if isinstance(layer, int):  # Check that this is a shared layer
                # TODO: give layer a meaningful name
                shared_net.append(nn.Linear(last_layer_dim_shared, layer))
                shared_net.append(activation_fn())
                last_layer_dim_shared = layer
            else:
                assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                if 'pi' in layer:
                    assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    policy_only_layers = layer['pi']

                if 'vf' in layer:
                    assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    value_only_layers = layer['vf']
                # From here on the network splits up in policy and value network;
                # any remaining element of net_arch is not used.
                break

        last_layer_dim_pi = last_layer_dim_shared
        last_layer_dim_vf = last_layer_dim_shared

        # Build the non-shared part of the network.
        # Policy and value layers are created alternately (zip_longest pads the
        # shorter list with None), which fixes the parameter-initialisation order.
        for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
            if pi_layer_size is not None:
                assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
                policy_net.append(activation_fn())
                last_layer_dim_pi = pi_layer_size

            if vf_layer_size is not None:
                assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size))
                value_net.append(activation_fn())
                last_layer_dim_vf = vf_layer_size

        # Save dim, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Create networks
        # If the list of layers is empty, the network will just act as an Identity module
        self.shared_net = nn.Sequential(*shared_net).to(device)
        self.policy_net = nn.Sequential(*policy_net).to(device)
        self.value_net = nn.Sequential(*value_net).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :param features: (th.Tensor) Input features; their last dimension is expected to be ``feature_dim``.
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        shared_latent = self.shared_net(features)
        return self.policy_net(shared_latent), self.value_net(shared_latent)