mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
572 lines
25 KiB
Python
572 lines
25 KiB
Python
from typing import Union, Type, Dict, List, Tuple, Optional, Any
|
|
|
|
from itertools import zip_longest
|
|
|
|
import gym
|
|
import torch as th
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
from stable_baselines3.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space
|
|
from stable_baselines3.common.utils import get_device
|
|
from stable_baselines3.common.vec_env import VecTransposeImage
|
|
|
|
|
|
class BaseFeaturesExtractor(nn.Module):
    """
    Base class for modules that map observations to a feature vector.

    Subclasses implement ``forward`` and must pass a strictly positive ``features_dim``.

    :param observation_space: (gym.Space) Observation space the extractor is built for.
    :param features_dim: (int) Number of features extracted (must be > 0).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        # NOTE: the default value 0 does not satisfy this check, so subclasses
        # always have to provide an explicit size.
        assert features_dim > 0
        self._observation_space, self._features_dim = observation_space, features_dim

    @property
    def features_dim(self) -> int:
        """(int) Declared number of features produced by this extractor."""
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        raise NotImplementedError()
|
|
|
|
|
|
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that only flattens its input.
    Used as a placeholder when no real feature extraction is needed.

    :param observation_space: (gym.Space) The feature size is the flattened
        dimension of this space.
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # ``nn.Flatten()`` with default arguments keeps the leading (batch) dim.
        return self.flatten(observations)
|
|
|
|
|
|
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from the DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning." Nature 518.7540 (2015): 529-533.

    NOTE(review): the Nature network's third convolution uses 64 filters, while this
    implementation uses 32. Changing it would alter the parameter shapes of saved
    models, so it is left as-is here. Confirm whether the deviation is intended.

    :param observation_space: (gym.Space) Image observation space, channels first.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units of the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box,
                 features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        # Images are expected channels first; any re-ordering is done upstream
        # (pre-processing or a wrapper).
        assert is_image_space(observation_space), ('You should use NatureCNN '
                                                   f'only with images not with {observation_space} '
                                                   '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
        in_channels = observation_space.shape[0]
        conv_stack = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), nn.ReLU(),
            nn.Flatten(),
        ]
        self.cnn = nn.Sequential(*conv_stack)

        # Infer the flattened conv output size from one sampled observation
        with th.no_grad():
            dummy_batch = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(dummy_batch).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        conv_features = self.cnn(observations)
        return self.linear(conv_features)
|
|
|
|
|
|
class BasePolicy(nn.Module):
    """
    The base policy object

    This base class only stores ``features_extractor_class`` / ``features_extractor_kwargs``;
    it does not instantiate the features extractor itself (presumably done by subclasses).

    :param observation_space: (gym.spaces.Space) The observation space of the environment
    :param action_space: (gym.spaces.Space) The action space of the environment
    :param device: (Union[th.device, str]) Device on which the code should run.
    :param squash_output: (bool) For continuous actions, whether the output is squashed
        or not using a ``tanh()`` function.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param features_extractor: (nn.Module) Network to extract features
        (a CNN when using images, a nn.Flatten() layer otherwise)
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(self, observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 device: Union[th.device, str] = 'auto',
                 features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                 features_extractor: Optional[nn.Module] = None,
                 normalize_images: bool = True,
                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                 optimizer_kwargs: Optional[Dict[str, Any]] = None,
                 squash_output: bool = False):
        super(BasePolicy, self).__init__()

        # ``None`` defaults avoid sharing one mutable dict between instances
        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        if features_extractor_kwargs is None:
            features_extractor_kwargs = {}

        self.observation_space = observation_space
        self.action_space = action_space
        self.device = get_device(device)
        self.features_extractor = features_extractor
        self.normalize_images = normalize_images
        self._squash_output = squash_output

        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs
        self.optimizer = None  # type: Optional[th.optim.Optimizer]

        self.features_extractor_class = features_extractor_class
        self.features_extractor_kwargs = features_extractor_kwargs

    def extract_features(self, obs: th.Tensor) -> th.Tensor:
        """
        Preprocess the observation if needed and extract features.

        :param obs: (th.Tensor)
        :return: (th.Tensor)
        :raises AssertionError: if no features extractor was set.
        """
        assert self.features_extractor is not None, 'No feature extractor was set'
        preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
        return self.features_extractor(preprocessed_obs)

    @property
    def squash_output(self) -> bool:
        """ (bool) Getter for squash_output."""
        return self._squash_output

    @staticmethod
    def init_weights(module: nn.Module, gain: float = 1) -> None:
        """
        Orthogonal initialization (used in PPO and A2C)

        ``nn.Linear`` and ``nn.Conv2d`` modules get an orthogonal weight matrix scaled by
        ``gain`` and a zero bias; any other module is left untouched.

        :param module: (nn.Module) Module to initialize in place.
        :param gain: (float) Scaling factor passed to ``nn.init.orthogonal_``.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.orthogonal_(module.weight, gain=gain)
            # Layers created with ``bias=False`` have ``bias = None``
            if module.bias is not None:
                module.bias.data.fill_(0.0)

    @staticmethod
    def _dummy_schedule(_progress: float) -> float:
        """ (float) Useful for pickling policy."""
        return 0.0

    def forward(self, *_args, **kwargs):
        raise NotImplementedError()

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation: (th.Tensor)
        :param deterministic: (bool) Whether to use stochastic or deterministic actions
        :return: (th.Tensor) Taken action according to the policy
        """
        raise NotImplementedError()

    def predict(self, observation: np.ndarray,
                state: Optional[np.ndarray] = None,
                mask: Optional[np.ndarray] = None,
                deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Get the policy action and state from an observation (and optional state).

        :param observation: (np.ndarray) the input observation
        :param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
        :param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
        :param deterministic: (bool) Whether or not to return deterministic actions.
        :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
            (used in recurrent policies)
        :raises ValueError: if the observation shape does not match the observation space,
            or if ``state`` is given together with a non-vectorized observation.
        """
        # NOTE: recurrent policies are not handled here yet: ``mask`` is unused
        # and ``state`` is returned unchanged.
        observation = np.array(observation)

        # Handle the different cases for images
        # as PyTorch use channel first format
        if is_image_space(self.observation_space):
            if (observation.shape == self.observation_space.shape or
                    observation.shape[1:] == self.observation_space.shape):
                pass
            else:
                # Try to re-order the channels
                transpose_obs = VecTransposeImage.transpose_image(observation)
                if (transpose_obs.shape == self.observation_space.shape
                        or transpose_obs.shape[1:] == self.observation_space.shape):
                    observation = transpose_obs

        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

        # Always work on a batch, even for a single observation
        observation = observation.reshape((-1,) + self.observation_space.shape)

        observation = th.as_tensor(observation).to(self.device)
        with th.no_grad():
            actions = self._predict(observation, deterministic=deterministic)
        # Convert to numpy
        actions = actions.cpu().numpy()

        # Rescale to proper domain when using squashing
        if isinstance(self.action_space, gym.spaces.Box) and self.squash_output:
            actions = self.unscale_action(actions)

        clipped_actions = actions
        # Clip the actions to avoid out of bound error when using gaussian distribution
        if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output:
            clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

        if not vectorized_env:
            if state is not None:
                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
            # Drop the batch dimension added above
            clipped_actions = clipped_actions[0]

        return clipped_actions, state

    def scale_action(self, action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)

        :param action: (np.ndarray) Action to scale
        :return: (np.ndarray) Scaled action
        """
        low, high = self.action_space.low, self.action_space.high
        return 2.0 * ((action - low) / (high - low)) - 1.0

    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
        """
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)

        :param scaled_action: (np.ndarray) Action to un-scale
        :return: (np.ndarray) Action in the [low, high] range of the action space
        """
        low, high = self.action_space.low, self.action_space.high
        return low + (0.5 * (scaled_action + 1.0) * (high - low))

    @staticmethod
    def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
        """
        For every observation type, detects and validates the shape,
        then returns whether or not the observation is vectorized.

        :param observation: (np.ndarray) the input observation to validate
        :param observation_space: (gym.spaces) the observation space
        :return: (bool) whether the given observation is vectorized or not
        :raises ValueError: on an unexpected shape or an unsupported space type
        """
        if isinstance(observation_space, gym.spaces.Box):
            if observation.shape == observation_space.shape:
                return False
            elif observation.shape[1:] == observation_space.shape:
                return True
            else:
                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
                                 "Box environment, please use {} ".format(observation_space.shape) +
                                 "or (n_env, {}) for the observation shape."
                                 .format(", ".join(map(str, observation_space.shape))))
        elif isinstance(observation_space, gym.spaces.Discrete):
            if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
                return False
            elif len(observation.shape) == 1:
                return True
            else:
                # The accepted shapes are () (single obs) and (n_env,) (vectorized)
                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
                                 "Discrete environment, please use () or (n_env,) for the observation shape.")
        # TODO: add support for MultiDiscrete and MultiBinary observation spaces
        # elif isinstance(observation_space, gym.spaces.MultiDiscrete):
        #     if observation.shape == (len(observation_space.nvec),):
        #         return False
        #     elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
        #         return True
        #     else:
        #         raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
        #                          "environment, please use ({},) or ".format(len(observation_space.nvec)) +
        #                          "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
        # elif isinstance(observation_space, gym.spaces.MultiBinary):
        #     if observation.shape == (observation_space.n,):
        #         return False
        #     elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
        #         return True
        #     else:
        #         raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
        #                          "environment, please use ({},) or ".format(observation_space.n) +
        #                          "(n_env, {}) for the observation shape.".format(observation_space.n))
        else:
            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
                             .format(observation_space))

    def _get_data(self) -> Dict[str, Any]:
        """
        Get data that need to be saved in order to re-create the policy.
        This corresponds to the arguments of the constructor.

        :return: (Dict[str, Any])
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            # Passed to the constructor by child class
            # squash_output=self.squash_output,
            # features_extractor=self.features_extractor
            normalize_images=self.normalize_images,
        )

    def save(self, path: str) -> None:
        """
        Save policy to a given location.

        Stores the ``state_dict`` together with the constructor data from ``_get_data``.

        :param path: (str)
        """
        th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)

    @classmethod
    def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy':
        """
        Load policy from path.

        :param path: (str)
        :param device: (Union[th.device, str]) Device on which the policy should be loaded.
        :return: (BasePolicy)
        """
        device = get_device(device)
        # NOTE: ``th.load`` unpickles the file; only load policies from trusted sources.
        saved_variables = th.load(path, map_location=device)
        # Create policy object
        model = cls(**saved_variables['data'])
        # Load weights
        model.load_state_dict(saved_variables['state_dict'])
        model.to(device)
        return model

    def load_from_vector(self, vector: np.ndarray):
        """
        Load parameters from a 1D vector.

        :param vector: (np.ndarray)
        """
        th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())

    def parameters_to_vector(self) -> np.ndarray:
        """
        Convert the parameters to a 1D vector.

        :return: (np.ndarray)
        """
        return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
|
|
|
|
|
def create_mlp(input_dim: int,
               output_dim: int,
               net_arch: List[int],
               activation_fn: Type[nn.Module] = nn.ReLU,
               squash_output: bool = False) -> List[nn.Module]:
    """
    Build the layers of a multi layer perceptron (MLP): a chain of
    fully-connected layers, each hidden one followed by an activation.

    :param input_dim: (int) Dimension of the input vector
    :param output_dim: (int) Size of the final linear layer; when <= 0,
        no output layer is added
    :param net_arch: (List[int]) Number of units of each hidden layer,
        in order. An empty list means no hidden layers.
    :param activation_fn: (Type[nn.Module]) Activation class, instantiated
        once after every hidden layer
    :param squash_output: (bool) Whether to append a Tanh module at the end
        (this is done even when no output layer is added)
    :return: (List[nn.Module]) The modules, ready for ``nn.Sequential(*modules)``
    """
    layers: List[nn.Module] = []
    in_features = input_dim
    for hidden_units in net_arch:
        layers += [nn.Linear(in_features, hidden_units), activation_fn()]
        in_features = hidden_units

    if output_dim > 0:
        layers.append(nn.Linear(in_features, output_dim))
    if squash_output:
        layers.append(nn.Tanh())
    return layers
|
|
|
|
|
|
def create_sde_features_extractor(features_dim: int,
                                  sde_net_arch: List[int],
                                  activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
    """
    Build the network that extracts features for the SDE exploration function.

    :param features_dim: (int) Input dimension of the network
    :param sde_net_arch: ([int]) Hidden layer sizes; an empty list yields an
        empty ``nn.Sequential`` (the input features are used directly)
    :param activation_fn: (Type[nn.Module]) Activation used after each hidden layer
    :return: (nn.Sequential, int) the network and the dimension of its output
    """
    if len(sde_net_arch) > 0:
        activation, latent_dim = activation_fn, sde_net_arch[-1]
    else:
        # Using the states themselves as features: no layer, no activation
        activation, latent_dim = None, features_dim

    layers = create_mlp(features_dim, -1, sde_net_arch, activation_fn=activation, squash_output=False)
    return nn.Sequential(*layers), latent_dim
|
|
|
|
|
|
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
|
|
|
|
|
|
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
    """
    Returns the registered policy from the base type and name

    :param base_policy_type: (Type[BasePolicy]) the base policy class
    :param name: (str) the policy name
    :return: (Type[BasePolicy]) the policy
    :raises ValueError: if ``base_policy_type`` has no registered policies,
        or if ``name`` is not registered for it.
    """
    if base_policy_type not in _policy_registry:
        raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
    registered = _policy_registry[base_policy_type]
    if name not in registered:
        # Both halves must be f-strings, otherwise the list of names is printed literally
        raise ValueError(f"Error: unknown policy type {name}, "
                         f"the only registered policy types are: {list(registered.keys())}!")
    return registered[name]
|
|
|
|
|
|
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
    """
    Register a policy, so it can be called using its name.
    e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...)

    Policies are grouped by the first direct subclass of ``BasePolicy`` they derive
    from, so the same name can be reused across groups. Registering the exact same
    ``(name, policy)`` pair again (e.g. when a module is re-imported) is a no-op.

    :param name: (str) the policy name
    :param policy: (Type[BasePolicy]) the policy class
    :raises ValueError: if ``policy`` does not derive from any direct subclass of
        ``BasePolicy``, or if ``name`` is already registered for a *different*
        policy in the same group.
    """
    sub_class = None
    # For building the doc
    try:
        for cls in BasePolicy.__subclasses__():
            if issubclass(policy, cls):
                sub_class = cls
                break
    except AttributeError:
        # NOTE(review): presumably only reached when dependencies are mocked (doc build);
        # real PyTorch does not appear to provide ``th.random.randint``.
        sub_class = str(th.random.randint(100))
    if sub_class is None:
        raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!")

    registered = _policy_registry.setdefault(sub_class, {})
    # Only refuse to override when the name points to another policy,
    # as the error message states
    if name in registered and registered[name] is not policy:
        raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
    registered[name] = policy
|
|
|
|
|
|
class MlpExtractor(nn.Module):
    """
    Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and
    a value network. The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
    of them are shared between the policy network and the value network. It is assumed to be a list with the following
    structure:

    1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
       If the number of ints is zero, there will be no shared layers.
    2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
       It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed.

    Anything after the first dict is ignored.

    For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value
    network of size 255 and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
    would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
    would be specified as [128, 128].

    Adapted from Stable Baselines.

    :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: (Type[nn.Module]) The activation function to use for the networks.
    :param device: (th.device)
    """

    def __init__(self, feature_dim: int,
                 net_arch: List[Union[int, Dict[str, List[int]]]],
                 activation_fn: Type[nn.Module],
                 device: Union[th.device, str] = 'auto'):
        super(MlpExtractor, self).__init__()
        device = get_device(device)

        shared_layers, pi_layers, vf_layers = [], [], []
        pi_sizes = []  # hidden sizes used only by the policy branch
        vf_sizes = []  # hidden sizes used only by the value branch
        shared_dim = feature_dim

        # Shared trunk: consume leading ints until the first dict
        for spec in net_arch:
            if not isinstance(spec, int):
                assert isinstance(spec, dict), "Error: the net_arch list can only contain ints and dicts"
                if 'pi' in spec:
                    assert isinstance(spec['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    pi_sizes = spec['pi']
                if 'vf' in spec:
                    assert isinstance(spec['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    vf_sizes = spec['vf']
                break  # From here on the network splits up in policy and value network
            shared_layers.extend([nn.Linear(shared_dim, spec), activation_fn()])
            shared_dim = spec

        pi_dim = vf_dim = shared_dim

        # Separate heads; layers are created interleaved (pi then vf at each depth)
        # so parameter initialisation consumes the RNG in a fixed order
        for pi_size, vf_size in zip_longest(pi_sizes, vf_sizes):
            if pi_size is not None:
                assert isinstance(pi_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                pi_layers.extend([nn.Linear(pi_dim, pi_size), activation_fn()])
                pi_dim = pi_size
            if vf_size is not None:
                assert isinstance(vf_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                vf_layers.extend([nn.Linear(vf_dim, vf_size), activation_fn()])
                vf_dim = vf_size

        # Output sizes of each branch, used to create the distributions
        self.latent_dim_pi = pi_dim
        self.latent_dim_vf = vf_dim

        # An empty layer list gives an nn.Sequential that returns its input unchanged
        self.shared_net = nn.Sequential(*shared_layers).to(device)
        self.policy_net = nn.Sequential(*pi_layers).to(device)
        self.value_net = nn.Sequential(*vf_layers).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        trunk_out = self.shared_net(features)
        latent_pi = self.policy_net(trunk_out)
        latent_vf = self.value_net(trunk_out)
        return latent_pi, latent_vf
|