Merge pull request #89 from DLR-RM/base-class-review

Refactor and clean-up of common code
This commit is contained in:
Adam Gleave 2020-07-07 18:53:51 -07:00 committed by GitHub
commit c39ed397ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 209 additions and 192 deletions

1
.gitignore vendored
View file

@ -22,6 +22,7 @@ keys/
# Virtualenv
/env
/venv
*.sublime-project

View file

@ -1,4 +1,5 @@
SHELL=/bin/bash
LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
pytest:
./scripts/run_tests.sh
@ -9,9 +10,9 @@ type:
lint:
# stop the build if there are Python syntax errors or undefined names
# see https://lintlyci.github.io/Flake8Rules/
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings.
flake8 . --count --exit-zero --statistics
flake8 ${LINT_PATHS} --count --exit-zero --statistics
doc:
cd docs && make html

View file

@ -19,7 +19,7 @@ from unittest.mock import MagicMock
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
# PyEnchant.
try:
import sphinxcontrib.spelling
import sphinxcontrib.spelling # noqa: F401
enable_spell_check = True
except ImportError:
enable_spell_check = False
@ -129,6 +129,7 @@ html_logo = '_static/img/logo.png'
def setup(app):
app.add_stylesheet("css/baselines_theme.css")
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.

View file

@ -1,5 +1,7 @@
"""Abstract base classes for RL algorithms."""
import time
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
import pathlib
@ -23,12 +25,30 @@ from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
"""If env is a string, make the environment; otherwise, return env.
:param env: (Union[GymEnv, str, None]) The environment to learn from.
:param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env.
:param verbose: (int) logging verbosity
:return: (Optional[GymEnv]) A Gym (vector) environment, or None if ``env`` is None.
"""
if isinstance(env, str):
if verbose >= 1:
print(f"Creating environment from the given name '{env}'")
env = gym.make(env)
if monitor_wrapper:
env = Monitor(env, filename=None)
return env
class BaseAlgorithm(ABC):
"""
The base of RL algorithms
:param policy: (Type[BasePolicy]) Policy object
:param env: (Union[GymEnv, str]) The environment to learn from
:param env: (Union[GymEnv, str, None]) The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: (Type[BasePolicy]) The base policy used by this method
:param learning_rate: (float or callable) learning rate for the optimizer,
@ -54,7 +74,7 @@ class BaseAlgorithm(ABC):
def __init__(self,
policy: Type[BasePolicy],
env: Union[GymEnv, str],
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Callable],
policy_kwargs: Dict[str, Any] = None,
@ -116,18 +136,9 @@ class BaseAlgorithm(ABC):
if env is not None:
if isinstance(env, str):
if create_eval_env:
eval_env = gym.make(env)
if monitor_wrapper:
eval_env = Monitor(eval_env, filename=None)
self.eval_env = DummyVecEnv([lambda: eval_env])
if self.verbose >= 1:
print("Creating environment from the given name, wrapped in a DummyVecEnv.")
env = gym.make(env)
if monitor_wrapper:
env = Monitor(env, filename=None)
env = DummyVecEnv([lambda: env])
self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose)
env = maybe_make_env(env, monitor_wrapper, self.verbose)
env = self._wrap_env(env)
self.observation_space = env.observation_space
@ -136,8 +147,8 @@ class BaseAlgorithm(ABC):
self.env = env
if not support_multi_env and self.n_envs > 1:
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
" environment.")
raise ValueError("Error: the model does not support multiple envs; it requires "
"a single vectorized environment.")
def _wrap_env(self, env: GymEnv) -> VecEnv:
if not isinstance(env, VecEnv):
@ -153,10 +164,7 @@ class BaseAlgorithm(ABC):
@abstractmethod
def _setup_model(self) -> None:
"""
Create networks, buffer and optimizers
"""
raise NotImplementedError()
"""Create networks, buffer and optimizers."""
def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
"""
@ -238,7 +246,7 @@ class BaseAlgorithm(ABC):
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
"""
Get the name of the torch variable that will be saved.
Get the name of the torch variables that will be saved.
``th.save`` and ``th.load`` will be used with the right device
instead of the default pickling strategy.
@ -263,10 +271,9 @@ class BaseAlgorithm(ABC):
Return a trained model.
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
It takes the local and global variables. If it returns False, training is aborted.
:param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm.
:param log_interval: (int) The number of timesteps before logging.
:param tb_log_name: (str) the name of the run for tensorboard log
:param tb_log_name: (str) the name of the run for TensorBoard logging
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
:param n_eval_episodes: (int) Number of episodes used to evaluate the agent
@ -274,7 +281,6 @@ class BaseAlgorithm(ABC):
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
:return: (BaseAlgorithm) the trained model
"""
raise NotImplementedError()
def predict(self, observation: np.ndarray,
state: Optional[np.ndarray] = None,
@ -329,8 +335,6 @@ class BaseAlgorithm(ABC):
# load parameters
model.__dict__.update(data)
model.__dict__.update(kwargs)
if not hasattr(model, "_setup_model") and len(params) > 0:
raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")
model._setup_model()
# put state_dicts back in place
@ -366,14 +370,18 @@ class BaseAlgorithm(ABC):
self.eval_env.seed(seed)
def _init_callback(self,
callback: Union[None, Callable, List[BaseCallback], BaseCallback],
callback: MaybeCallback,
eval_env: Optional[VecEnv] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None) -> BaseCallback:
"""
:param callback: (Union[callable, [BaseCallback], BaseCallback, None])
:return: (BaseCallback)
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
:param eval_env: (Optional[VecEnv]) Environment to use for evaluation.
:param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate.
:param n_eval_episodes: (int) Number of episodes to rollout during evaluation.
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation.
"""
# Convert a list of callbacks into a callback
if isinstance(callback, list):
@ -396,7 +404,7 @@ class BaseAlgorithm(ABC):
def _setup_learn(self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
callback: MaybeCallback = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
@ -407,11 +415,11 @@ class BaseAlgorithm(ABC):
Initialize different variables needed for training.
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param eval_env: (Optional[GymEnv])
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
:param eval_env: (Optional[VecEnv]) Environment to use for evaluation.
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
:param eval_freq: (int) How many steps between evaluations
:param n_eval_episodes: (int) How many episodes to play per evaluation
:param log_path (Optional[str]): Path to a log folder
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
:param tb_log_name: (str) the name of the run for tensorboard log
:return: (Tuple[int, BaseCallback])
@ -480,8 +488,8 @@ class BaseAlgorithm(ABC):
def save(
self,
path: Union[str, pathlib.Path, io.BufferedIOBase],
exclude: Optional[List[str]] = None,
include: Optional[List[str]] = None,
exclude: Optional[Iterable[str]] = None,
include: Optional[Iterable[str]] = None,
) -> None:
"""
Save all the attributes of the object and the model parameters in a zip-file.
@ -492,16 +500,15 @@ class BaseAlgorithm(ABC):
"""
# copy parameter list so we don't mutate the original dict
data = self.__dict__.copy()
# use standard list of excluded parameters if none given
if exclude is None:
exclude = self.excluded_save_params()
else:
# append standard exclude params to the given params
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
# do not exclude params if they are specifically included
# Exclude is union of specified parameters (if any) and standard exclusions
if exclude is None:
exclude = []
exclude = set(exclude).union(self.excluded_save_params())
# Do not exclude params if they are specifically included
if include is not None:
exclude = [param_name for param_name in exclude if param_name not in include]
exclude = exclude.difference(include)
state_dicts_names, tensors_names = self.get_torch_variables()
# any params that are in the save vars must not be saved by data
@ -509,12 +516,11 @@ class BaseAlgorithm(ABC):
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
var_name = torch_var.split('.')[0]
exclude.append(var_name)
exclude.add(var_name)
# Remove parameter entries of parameters which are to be excluded
for param_name in exclude:
if param_name in data:
data.pop(param_name, None)
data.pop(param_name, None)
# Build dict of tensor variables
tensors = None

View file

@ -1,3 +1,6 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Dict, Any, List
import gym
import torch as th
@ -8,36 +11,38 @@ from gym import spaces
from stable_baselines3.common.preprocessing import get_action_dim
class Distribution(object):
class Distribution(ABC):
"""Abstract base class for distributions."""
def __init__(self):
super(Distribution, self).__init__()
@abstractmethod
def log_prob(self, x: th.Tensor) -> th.Tensor:
"""
returns the log likelihood
Returns the log likelihood
:param x: (th.Tensor) the taken action
:return: (th.Tensor) The log likelihood of the distribution
"""
raise NotImplementedError
@abstractmethod
def entropy(self) -> Optional[th.Tensor]:
"""
Returns Shannon's entropy of the probability
:return: (Optional[th.Tensor]) the entropy,
return None if no analytical form is known
:return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known
"""
raise NotImplementedError
@abstractmethod
def sample(self) -> th.Tensor:
"""
Returns a sample from the probability distribution
:return: (th.Tensor) the stochastic action
"""
raise NotImplementedError
@abstractmethod
def mode(self) -> th.Tensor:
"""
Returns the most likely action (deterministic output)
@ -45,7 +50,6 @@ class Distribution(object):
:return: (th.Tensor) the deterministic (most likely) action
"""
raise NotImplementedError
def get_actions(self, deterministic: bool = False) -> th.Tensor:
"""
@ -58,6 +62,7 @@ class Distribution(object):
return self.mode()
return self.sample()
@abstractmethod
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
"""
Returns samples from the probability distribution
@ -65,8 +70,8 @@ class Distribution(object):
:return: (th.Tensor) actions
"""
raise NotImplementedError
@abstractmethod
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
"""
Returns samples and the associated log probabilities
@ -74,14 +79,12 @@ class Distribution(object):
:return: (Tuple[th.Tensor, th.Tensor]) actions and log prob
"""
raise NotImplementedError
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
"""
Continuous actions are usually considered to be independent,
so we can sum the components for the ``log_prob``
or the entropy.
so we can sum components of the ``log_prob`` or the entropy.
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
:return: (th.Tensor) shape: (n_batch,)
@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
class DiagGaussianDistribution(Distribution):
"""
Gaussian distribution with diagonal covariance matrix,
for continuous actions.
Gaussian distribution with diagonal covariance matrix, for continuous actions.
:param action_dim: (int) Dimension of the action space.
"""
@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution):
one output will be the mean of the Gaussian, the other parameter will be the
standard deviation (log std in fact to allow negative values)
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
:param log_std_init: (float) Initial value for the log standard deviation
:return: (nn.Linear, nn.Parameter)
"""
@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution):
self.distribution = Normal(mean_actions, action_std)
return self
def mode(self) -> th.Tensor:
return self.distribution.mean
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must first call the ``proba_distribution()`` method.
:param actions: (th.Tensor)
:return: (th.Tensor)
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
return self.distribution.rsample()
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
def mode(self) -> th.Tensor:
return self.distribution.mean
def actions_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor,
@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
"""
Get the log probabilities of actions according to the distribution.
Note that you must call ``proba_distribution()`` method before.
:param actions: (th.Tensor)
:return: (th.Tensor)
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
"""
Gaussian distribution with diagonal covariance matrix,
followed by a squashing function (tanh) to ensure bounds.
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
:param action_dim: (int) Dimension of the action space.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
return self
def mode(self) -> th.Tensor:
self.gaussian_actions = self.distribution.mean
# Squash the output
return th.tanh(self.gaussian_actions)
def entropy(self) -> Optional[th.Tensor]:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
self.gaussian_actions = self.distribution.rsample()
return th.tanh(self.gaussian_actions)
def log_prob_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
def log_prob(self, actions: th.Tensor,
gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
# Inverse tanh
@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients
self.gaussian_actions = super().sample()
return th.tanh(self.gaussian_actions)
def mode(self) -> th.Tensor:
self.gaussian_actions = super().mode()
# Squash the output
return th.tanh(self.gaussian_actions)
def log_prob_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
class CategoricalDistribution(Distribution):
"""
@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution):
self.distribution = Categorical(logits=action_logits)
return self
def mode(self) -> th.Tensor:
return th.argmax(self.distribution.probs, dim=1)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)
def entropy(self) -> th.Tensor:
return self.distribution.entropy()
def sample(self) -> th.Tensor:
return self.distribution.sample()
def entropy(self) -> th.Tensor:
return self.distribution.entropy()
def mode(self) -> th.Tensor:
return th.argmax(self.distribution.probs, dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions)
class MultiCategoricalDistribution(Distribution):
"""
@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution):
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
def sample(self) -> th.Tensor:
return th.stack([dist.sample() for dist in self.distributions], dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
class BernoulliDistribution(Distribution):
"""
@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution):
self.distribution = Bernoulli(logits=action_logits)
return self
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)
def sample(self) -> th.Tensor:
return self.distribution.sample()
def entropy(self) -> th.Tensor:
return self.distribution.entropy().sum(dim=1)
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
def actions_from_params(self, action_logits: th.Tensor,
deterministic: bool = False) -> th.Tensor:
@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return self.distribution.log_prob(actions).sum(dim=1)
class StateDependentNoiseDistribution(Distribution):
"""
@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution):
a positive standard deviation (cf paper). It keeps the variance
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries.
this ensures bounds are satisfied.
:param learn_features: (bool) Whether to learn features for gSDE or not.
This will enable gradients to be backpropagated through the features
``latent_sde`` in the code.
@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution):
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
if self.bijector is not None:
gaussian_actions = self.bijector.inverse(actions)
else:
gaussian_actions = actions
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_actions)
# Sum along action dim
log_prob = sum_independent_dims(log_prob)
if self.bijector is not None:
# Squash correction (from original SAC implementation)
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
return log_prob
def entropy(self) -> Optional[th.Tensor]:
if self.bijector is not None:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
return None
return sum_independent_dims(self.distribution.entropy())
def sample(self) -> th.Tensor:
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def mode(self) -> th.Tensor:
actions = self.distribution.mean
if self.bijector is not None:
@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution):
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(1)
def sample(self) -> th.Tensor:
noise = self.get_noise(self._latent_sde)
actions = self.distribution.mean + noise
if self.bijector is not None:
return self.bijector.forward(actions)
return actions
def entropy(self) -> Optional[th.Tensor]:
# No analytical form,
# entropy needs to be estimated using -log_prob.mean()
if self.bijector is not None:
return None
return sum_independent_dims(self.distribution.entropy())
def actions_from_params(self, mean_actions: th.Tensor,
log_std: th.Tensor,
latent_sde: th.Tensor,
@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution):
log_prob = self.log_prob(actions)
return actions, log_prob
def log_prob(self, actions: th.Tensor) -> th.Tensor:
if self.bijector is not None:
gaussian_actions = self.bijector.inverse(actions)
else:
gaussian_actions = actions
# log likelihood for a gaussian
log_prob = self.distribution.log_prob(gaussian_actions)
# Sum along action dim
log_prob = sum_independent_dims(log_prob)
if self.bijector is not None:
# Squash correction (from original SAC implementation)
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
return log_prob
class TanhBijector(object):
"""
@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
if use_sde:
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):

View file

@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
self.action_space, self.device,
optimize_memory_usage=self.optimize_memory_usage)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, **self.policy_kwargs)
self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable
self.policy = self.policy.to(self.device)
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:

View file

@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
n_envs=self.n_envs)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, use_sde=self.use_sde, device=self.device,
**self.policy_kwargs)
**self.policy_kwargs) # pytype:disable=not-instantiable
self.policy = self.policy.to(self.device)
def collect_rollouts(self,

View file

@ -1,3 +1,7 @@
"""Policies: abstract base class and concrete implementations."""
from abc import ABC, abstractmethod
import collections
from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
from functools import partial
@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis
StateDependentNoiseDistribution)
class BasePolicy(nn.Module):
class BasePolicy(nn.Module, ABC):
"""
The base policy object
@ -98,17 +102,22 @@ class BasePolicy(nn.Module):
module.bias.data.fill_(0.0)
@staticmethod
def _dummy_schedule(_progress_remaining: float) -> float:
def _dummy_schedule(progress_remaining: float) -> float:
""" (float) Useful for pickling policy."""
del progress_remaining
return 0.0
def forward(self, *_args, **kwargs):
raise NotImplementedError()
@abstractmethod
def forward(self, *args, **kwargs):
del args, kwargs
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
By default provides a dummy implementation -- not all BasePolicy classes
implement this, e.g. if they are a Critic in an Actor-Critic method.
:param observation: (th.Tensor)
:param deterministic: (bool) Whether to use stochastic or deterministic actions
:return: (th.Tensor) Taken action according to the policy
@ -140,10 +149,8 @@ class BasePolicy(nn.Module):
# Handle the different cases for images
# as PyTorch uses channel-first format
if is_image_space(self.observation_space):
if (observation.shape == self.observation_space.shape
if not (observation.shape == self.observation_space.shape
or observation.shape[1:] == self.observation_space.shape):
pass
else:
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if (transpose_obs.shape == self.observation_space.shape
@ -160,21 +167,21 @@ class BasePolicy(nn.Module):
# Convert to numpy
actions = actions.cpu().numpy()
# Rescale to proper domain when using squashing
if isinstance(self.action_space, gym.spaces.Box) and self.squash_output:
actions = self.unscale_action(actions)
clipped_actions = actions
# Clip the actions to avoid out of bound error when using gaussian distribution
if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output:
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
clipped_actions = clipped_actions[0]
actions = actions[0]
return clipped_actions, state
return actions, state
def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
@ -227,7 +234,7 @@ class BasePolicy(nn.Module):
Load policy from path.
:param path: (str)
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
:return: (BasePolicy)
"""
device = get_device(device)
@ -294,7 +301,7 @@ class ActorCriticPolicy(BasePolicy):
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = 'auto',
activation_fn: Type[nn.Module] = nn.Tanh,
@ -313,7 +320,7 @@ class ActorCriticPolicy(BasePolicy):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in ADAM optimizer
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs['eps'] = 1e-5
@ -366,15 +373,17 @@ class ActorCriticPolicy(BasePolicy):
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
data.update(dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
squash_output=default_none_kwargs['squash_output'],
full_std=default_none_kwargs['full_std'],
sde_net_arch=default_none_kwargs['sde_net_arch'],
use_expln=default_none_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
@ -394,7 +403,7 @@ class ActorCriticPolicy(BasePolicy):
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, lr_schedule: Callable) -> None:
def _build(self, lr_schedule: Callable[[float], float]) -> None:
"""
Create the networks and the optimizer.
@ -719,10 +728,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
:return: (Type[BasePolicy]) the policy
"""
if base_policy_type not in _policy_registry:
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise ValueError(f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
raise KeyError(f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
return _policy_registry[base_policy_type][name]

View file

@ -1,6 +1,5 @@
"""
Common aliases for type hint
"""
"""Common aliases for type hints"""
from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple
import numpy as np