mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Merge pull request #89 from DLR-RM/base-class-review
Refactor and clean-up of common code
This commit is contained in:
commit
c39ed397ac
9 changed files with 209 additions and 192 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -22,6 +22,7 @@ keys/
|
|||
|
||||
# Virtualenv
|
||||
/env
|
||||
/venv
|
||||
|
||||
|
||||
*.sublime-project
|
||||
|
|
|
|||
5
Makefile
5
Makefile
|
|
@ -1,4 +1,5 @@
|
|||
SHELL=/bin/bash
|
||||
LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py
|
||||
|
||||
pytest:
|
||||
./scripts/run_tests.sh
|
||||
|
|
@ -9,9 +10,9 @@ type:
|
|||
lint:
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
# see https://lintlyci.github.io/Flake8Rules/
|
||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics
|
||||
# exit-zero treats all errors as warnings.
|
||||
flake8 . --count --exit-zero --statistics
|
||||
flake8 ${LINT_PATHS} --count --exit-zero --statistics
|
||||
|
||||
doc:
|
||||
cd docs && make html
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from unittest.mock import MagicMock
|
|||
# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
|
||||
# PyEnchant.
|
||||
try:
|
||||
import sphinxcontrib.spelling
|
||||
import sphinxcontrib.spelling # noqa: F401
|
||||
enable_spell_check = True
|
||||
except ImportError:
|
||||
enable_spell_check = False
|
||||
|
|
@ -129,6 +129,7 @@ html_logo = '_static/img/logo.png'
|
|||
def setup(app):
|
||||
app.add_stylesheet("css/baselines_theme.css")
|
||||
|
||||
|
||||
# Theme options are theme-specific and customize the look and feel of a theme
|
||||
# further. For a list of options available for each theme, see the
|
||||
# documentation.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""Abstract base classes for RL algorithms."""
|
||||
|
||||
import time
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
import pathlib
|
||||
|
|
@ -23,12 +25,30 @@ from stable_baselines3.common.monitor import Monitor
|
|||
from stable_baselines3.common.noise import ActionNoise
|
||||
|
||||
|
||||
def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
|
||||
"""If env is a string, make the environment; otherwise, return env.
|
||||
|
||||
:param env: (Union[GymEnv, str, None]) The environment to learn from.
|
||||
:param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env.
|
||||
:param verbose: (int) logging verbosity
|
||||
:return A Gym (vector) environment.
|
||||
"""
|
||||
if isinstance(env, str):
|
||||
if verbose >= 1:
|
||||
print(f"Creating environment from the given name '{env}'")
|
||||
env = gym.make(env)
|
||||
if monitor_wrapper:
|
||||
env = Monitor(env, filename=None)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class BaseAlgorithm(ABC):
|
||||
"""
|
||||
The base of RL algorithms
|
||||
|
||||
:param policy: (Type[BasePolicy]) Policy object
|
||||
:param env: (Union[GymEnv, str]) The environment to learn from
|
||||
:param env: (Union[GymEnv, str, None]) The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: (Type[BasePolicy]) The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
|
|
@ -54,7 +74,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
def __init__(self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
env: Union[GymEnv, str, None],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
|
|
@ -116,18 +136,9 @@ class BaseAlgorithm(ABC):
|
|||
if env is not None:
|
||||
if isinstance(env, str):
|
||||
if create_eval_env:
|
||||
eval_env = gym.make(env)
|
||||
if monitor_wrapper:
|
||||
eval_env = Monitor(eval_env, filename=None)
|
||||
self.eval_env = DummyVecEnv([lambda: eval_env])
|
||||
if self.verbose >= 1:
|
||||
print("Creating environment from the given name, wrapped in a DummyVecEnv.")
|
||||
|
||||
env = gym.make(env)
|
||||
if monitor_wrapper:
|
||||
env = Monitor(env, filename=None)
|
||||
env = DummyVecEnv([lambda: env])
|
||||
self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose)
|
||||
|
||||
env = maybe_make_env(env, monitor_wrapper, self.verbose)
|
||||
env = self._wrap_env(env)
|
||||
|
||||
self.observation_space = env.observation_space
|
||||
|
|
@ -136,8 +147,8 @@ class BaseAlgorithm(ABC):
|
|||
self.env = env
|
||||
|
||||
if not support_multi_env and self.n_envs > 1:
|
||||
raise ValueError("Error: the model does not support multiple envs requires a single vectorized"
|
||||
" environment.")
|
||||
raise ValueError("Error: the model does not support multiple envs; it requires "
|
||||
"a single vectorized environment.")
|
||||
|
||||
def _wrap_env(self, env: GymEnv) -> VecEnv:
|
||||
if not isinstance(env, VecEnv):
|
||||
|
|
@ -153,10 +164,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def _setup_model(self) -> None:
|
||||
"""
|
||||
Create networks, buffer and optimizers
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
"""Create networks, buffer and optimizers."""
|
||||
|
||||
def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
|
||||
"""
|
||||
|
|
@ -238,7 +246,7 @@ class BaseAlgorithm(ABC):
|
|||
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Get the name of the torch variable that will be saved.
|
||||
Get the name of the torch variables that will be saved.
|
||||
``th.save`` and ``th.load`` will be used with the right device
|
||||
instead of the default pickling strategy.
|
||||
|
||||
|
|
@ -263,10 +271,9 @@ class BaseAlgorithm(ABC):
|
|||
Return a trained model.
|
||||
|
||||
:param total_timesteps: (int) The total number of samples (env steps) to train on
|
||||
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
|
||||
It takes the local and global variables. If it returns False, training is aborted.
|
||||
:param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm.
|
||||
:param log_interval: (int) The number of timesteps before logging.
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:param tb_log_name: (str) the name of the run for TensorBoard logging
|
||||
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
|
||||
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
||||
:param n_eval_episodes: (int) Number of episode to evaluate the agent
|
||||
|
|
@ -274,7 +281,6 @@ class BaseAlgorithm(ABC):
|
|||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
|
||||
:return: (BaseAlgorithm) the trained model
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self, observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
|
|
@ -329,8 +335,6 @@ class BaseAlgorithm(ABC):
|
|||
# load parameters
|
||||
model.__dict__.update(data)
|
||||
model.__dict__.update(kwargs)
|
||||
if not hasattr(model, "_setup_model") and len(params) > 0:
|
||||
raise NotImplementedError(f"{cls} has no ``_setup_model()`` method")
|
||||
model._setup_model()
|
||||
|
||||
# put state_dicts back in place
|
||||
|
|
@ -366,14 +370,18 @@ class BaseAlgorithm(ABC):
|
|||
self.eval_env.seed(seed)
|
||||
|
||||
def _init_callback(self,
|
||||
callback: Union[None, Callable, List[BaseCallback], BaseCallback],
|
||||
callback: MaybeCallback,
|
||||
eval_env: Optional[VecEnv] = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None) -> BaseCallback:
|
||||
"""
|
||||
:param callback: (Union[callable, [BaseCallback], BaseCallback, None])
|
||||
:return: (BaseCallback)
|
||||
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate.
|
||||
:param n_eval_episodes: (int) How many episodes to play per evaluation
|
||||
:param n_eval_episodes: (int) Number of episodes to rollout during evaluation.
|
||||
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
|
||||
:return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation.
|
||||
"""
|
||||
# Convert a list of callbacks into a callback
|
||||
if isinstance(callback, list):
|
||||
|
|
@ -396,7 +404,7 @@ class BaseAlgorithm(ABC):
|
|||
def _setup_learn(self,
|
||||
total_timesteps: int,
|
||||
eval_env: Optional[GymEnv],
|
||||
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
|
||||
callback: MaybeCallback = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
|
|
@ -407,11 +415,11 @@ class BaseAlgorithm(ABC):
|
|||
Initialize different variables needed for training.
|
||||
|
||||
:param total_timesteps: (int) The total number of samples (env steps) to train on
|
||||
:param eval_env: (Optional[GymEnv])
|
||||
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
|
||||
:param eval_env: (Optional[VecEnv]) Environment to use for evaluation.
|
||||
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
|
||||
:param eval_freq: (int) How many steps between evaluations
|
||||
:param n_eval_episodes: (int) How many episodes to play per evaluation
|
||||
:param log_path (Optional[str]): Path to a log folder
|
||||
:param log_path: (Optional[str]) Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:return: (Tuple[int, BaseCallback])
|
||||
|
|
@ -480,8 +488,8 @@ class BaseAlgorithm(ABC):
|
|||
def save(
|
||||
self,
|
||||
path: Union[str, pathlib.Path, io.BufferedIOBase],
|
||||
exclude: Optional[List[str]] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[Iterable[str]] = None,
|
||||
include: Optional[Iterable[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save all the attributes of the object and the model parameters in a zip-file.
|
||||
|
|
@ -492,16 +500,15 @@ class BaseAlgorithm(ABC):
|
|||
"""
|
||||
# copy parameter list so we don't mutate the original dict
|
||||
data = self.__dict__.copy()
|
||||
# use standard list of excluded parameters if none given
|
||||
if exclude is None:
|
||||
exclude = self.excluded_save_params()
|
||||
else:
|
||||
# append standard exclude params to the given params
|
||||
exclude.extend([param for param in self.excluded_save_params() if param not in exclude])
|
||||
|
||||
# do not exclude params if they are specifically included
|
||||
# Exclude is union of specified parameters (if any) and standard exclusions
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
exclude = set(exclude).union(self.excluded_save_params())
|
||||
|
||||
# Do not exclude params if they are specifically included
|
||||
if include is not None:
|
||||
exclude = [param_name for param_name in exclude if param_name not in include]
|
||||
exclude = exclude.difference(include)
|
||||
|
||||
state_dicts_names, tensors_names = self.get_torch_variables()
|
||||
# any params that are in the save vars must not be saved by data
|
||||
|
|
@ -509,12 +516,11 @@ class BaseAlgorithm(ABC):
|
|||
for torch_var in torch_variables:
|
||||
# we need to get only the name of the top most module as we'll remove that
|
||||
var_name = torch_var.split('.')[0]
|
||||
exclude.append(var_name)
|
||||
exclude.add(var_name)
|
||||
|
||||
# Remove parameter entries of parameters which are to be excluded
|
||||
for param_name in exclude:
|
||||
if param_name in data:
|
||||
data.pop(param_name, None)
|
||||
data.pop(param_name, None)
|
||||
|
||||
# Build dict of tensor variables
|
||||
tensors = None
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
"""Probability distributions."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, Dict, Any, List
|
||||
import gym
|
||||
import torch as th
|
||||
|
|
@ -8,36 +11,38 @@ from gym import spaces
|
|||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
|
||||
|
||||
class Distribution(object):
|
||||
class Distribution(ABC):
|
||||
"""Abstract base class for distributions."""
|
||||
|
||||
def __init__(self):
|
||||
super(Distribution, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def log_prob(self, x: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
returns the log likelihood
|
||||
Returns the log likelihood
|
||||
|
||||
:param x: (th.Tensor) the taken action
|
||||
:return: (th.Tensor) The log likelihood of the distribution
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
"""
|
||||
Returns Shannon's entropy of the probability
|
||||
|
||||
:return: (Optional[th.Tensor]) the entropy,
|
||||
return None if no analytical form is known
|
||||
:return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def sample(self) -> th.Tensor:
|
||||
"""
|
||||
Returns a sample from the probability distribution
|
||||
|
||||
:return: (th.Tensor) the stochastic action
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def mode(self) -> th.Tensor:
|
||||
"""
|
||||
Returns the most likely action (deterministic output)
|
||||
|
|
@ -45,7 +50,6 @@ class Distribution(object):
|
|||
|
||||
:return: (th.Tensor) the stochastic action
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_actions(self, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
|
|
@ -58,6 +62,7 @@ class Distribution(object):
|
|||
return self.mode()
|
||||
return self.sample()
|
||||
|
||||
@abstractmethod
|
||||
def actions_from_params(self, *args, **kwargs) -> th.Tensor:
|
||||
"""
|
||||
Returns samples from the probability distribution
|
||||
|
|
@ -65,8 +70,8 @@ class Distribution(object):
|
|||
|
||||
:return: (th.Tensor) actions
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Returns samples and the associated log probabilities
|
||||
|
|
@ -74,14 +79,12 @@ class Distribution(object):
|
|||
|
||||
:return: (th.Tuple[th.Tensor, th.Tensor]) actions and log prob
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Continuous actions are usually considered to be independent,
|
||||
so we can sum the components for the ``log_prob``
|
||||
or the entropy.
|
||||
so we can sum components of the ``log_prob`` or the entropy.
|
||||
|
||||
:param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,)
|
||||
:return: (th.Tensor) shape: (n_batch,)
|
||||
|
|
@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
|
|||
|
||||
class DiagGaussianDistribution(Distribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
for continuous actions.
|
||||
Gaussian distribution with diagonal covariance matrix, for continuous actions.
|
||||
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
"""
|
||||
|
|
@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution):
|
|||
one output will be the mean of the Gaussian, the other parameter will be the
|
||||
standard deviation (log std in fact to allow negative values)
|
||||
|
||||
:param latent_dim: (int) Dimension og the last layer of the policy (before the action layer)
|
||||
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
|
|
@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution):
|
|||
self.distribution = Normal(mean_actions, action_std)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must first call the ``proba_distribution()`` method.
|
||||
|
||||
:param actions: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
return self.distribution.rsample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
def mode(self) -> th.Tensor:
|
||||
return self.distribution.mean
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor,
|
||||
|
|
@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Get the log probabilities of actions according to the distribution.
|
||||
Note that you must call ``proba_distribution()`` method before.
|
||||
|
||||
:param actions: (th.Tensor)
|
||||
:return: (th.Tensor)
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
|
||||
|
||||
class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
||||
"""
|
||||
Gaussian distribution with diagonal covariance matrix,
|
||||
followed by a squashing function (tanh) to ensure bounds.
|
||||
Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds.
|
||||
|
||||
:param action_dim: (int) Dimension of the action space.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
|
|
@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
self.gaussian_actions = self.distribution.mean
|
||||
# Squash the output
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
self.gaussian_actions = self.distribution.rsample()
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
action = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(action, self.gaussian_actions)
|
||||
return action, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor,
|
||||
gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
|
||||
# Inverse tanh
|
||||
|
|
@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
|
|||
log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1)
|
||||
return log_prob
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
self.gaussian_actions = super().sample()
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
self.gaussian_actions = super().mode()
|
||||
# Squash the output
|
||||
return th.tanh(self.gaussian_actions)
|
||||
|
||||
def log_prob_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
action = self.actions_from_params(mean_actions, log_std)
|
||||
log_prob = self.log_prob(action, self.gaussian_actions)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
class CategoricalDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution):
|
|||
self.distribution = Categorical(logits=action_logits)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.argmax(self.distribution.probs, dim=1)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy()
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return self.distribution.sample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy()
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.argmax(self.distribution.probs, dim=1)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions)
|
||||
|
||||
|
||||
class MultiCategoricalDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution):
|
|||
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
# Extract each discrete action and compute log prob for their respective distributions
|
||||
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
|
||||
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return th.stack([dist.sample() for dist in self.distributions], dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
# Extract each discrete action and compute log prob for their respective distributions
|
||||
return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
|
||||
th.unbind(actions, dim=1))], dim=1).sum(dim=1)
|
||||
|
||||
|
||||
class BernoulliDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution):
|
|||
self.distribution = Bernoulli(logits=action_logits)
|
||||
return self
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.round(self.distribution.probs)
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions).sum(dim=1)
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy().sum(dim=1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
return self.distribution.sample()
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return self.distribution.entropy().sum(dim=1)
|
||||
def mode(self) -> th.Tensor:
|
||||
return th.round(self.distribution.probs)
|
||||
|
||||
def actions_from_params(self, action_logits: th.Tensor,
|
||||
deterministic: bool = False) -> th.Tensor:
|
||||
|
|
@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return self.distribution.log_prob(actions).sum(dim=1)
|
||||
|
||||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
"""
|
||||
|
|
@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries.
|
||||
this ensures bounds are satisfied.
|
||||
:param learn_features: (bool) Whether to learn features for gSDE or not.
|
||||
This will enable gradients to be backpropagated through the features
|
||||
``latent_sde`` in the code.
|
||||
|
|
@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon))
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
if self.bijector is not None:
|
||||
gaussian_actions = self.bijector.inverse(actions)
|
||||
else:
|
||||
gaussian_actions = actions
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||
# Sum along action dim
|
||||
log_prob = sum_independent_dims(log_prob)
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
|
||||
return log_prob
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
if self.bijector is not None:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
return None
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
actions = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(actions)
|
||||
return actions
|
||||
|
||||
def mode(self) -> th.Tensor:
|
||||
actions = self.distribution.mean
|
||||
if self.bijector is not None:
|
||||
|
|
@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
noise = th.bmm(latent_sde, self.exploration_matrices)
|
||||
return noise.squeeze(1)
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
noise = self.get_noise(self._latent_sde)
|
||||
actions = self.distribution.mean + noise
|
||||
if self.bijector is not None:
|
||||
return self.bijector.forward(actions)
|
||||
return actions
|
||||
|
||||
def entropy(self) -> Optional[th.Tensor]:
|
||||
# No analytical form,
|
||||
# entropy needs to be estimated using -log_prob.mean()
|
||||
if self.bijector is not None:
|
||||
return None
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
|
||||
def actions_from_params(self, mean_actions: th.Tensor,
|
||||
log_std: th.Tensor,
|
||||
latent_sde: th.Tensor,
|
||||
|
|
@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
log_prob = self.log_prob(actions)
|
||||
return actions, log_prob
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
if self.bijector is not None:
|
||||
gaussian_actions = self.bijector.inverse(actions)
|
||||
else:
|
||||
gaussian_actions = actions
|
||||
# log likelihood for a gaussian
|
||||
log_prob = self.distribution.log_prob(gaussian_actions)
|
||||
# Sum along action dim
|
||||
log_prob = sum_independent_dims(log_prob)
|
||||
|
||||
if self.bijector is not None:
|
||||
# Squash correction (from original SAC implementation)
|
||||
log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1)
|
||||
return log_prob
|
||||
|
||||
|
||||
class TanhBijector(object):
|
||||
"""
|
||||
|
|
@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
|
|||
|
||||
if isinstance(action_space, spaces.Box):
|
||||
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
|
||||
if use_sde:
|
||||
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
|
||||
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
|
||||
return cls(get_action_dim(action_space), **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.Discrete):
|
||||
return CategoricalDistribution(action_space.n, **dist_kwargs)
|
||||
elif isinstance(action_space, spaces.MultiDiscrete):
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
|
|||
self.action_space, self.device,
|
||||
optimize_memory_usage=self.optimize_memory_usage)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, **self.policy_kwargs)
|
||||
self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
|
|||
n_envs=self.n_envs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
**self.policy_kwargs) # pytype:disable=not-instantiable
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def collect_rollouts(self,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
"""Policies: abstract base class and concrete implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import collections
|
||||
from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
|
||||
from functools import partial
|
||||
|
||||
|
|
@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis
|
|||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class BasePolicy(nn.Module):
|
||||
class BasePolicy(nn.Module, ABC):
|
||||
"""
|
||||
The base policy object
|
||||
|
||||
|
|
@ -98,17 +102,22 @@ class BasePolicy(nn.Module):
|
|||
module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(_progress_remaining: float) -> float:
|
||||
def _dummy_schedule(progress_remaining: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
del progress_remaining
|
||||
return 0.0
|
||||
|
||||
def forward(self, *_args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
||||
By default provides a dummy implementation -- not all BasePolicy classes
|
||||
implement this, e.g. if they are a Critic in an Actor-Critic method.
|
||||
|
||||
:param observation: (th.Tensor)
|
||||
:param deterministic: (bool) Whether to use stochastic or deterministic actions
|
||||
:return: (th.Tensor) Taken action according to the policy
|
||||
|
|
@@ -140,10 +149,8 @@ class BasePolicy(nn.Module):
|
|||
# Handle the different cases for images
|
||||
# as PyTorch use channel first format
|
||||
if is_image_space(self.observation_space):
|
||||
if (observation.shape == self.observation_space.shape
|
||||
if not (observation.shape == self.observation_space.shape
|
||||
or observation.shape[1:] == self.observation_space.shape):
|
||||
pass
|
||||
else:
|
||||
# Try to re-order the channels
|
||||
transpose_obs = VecTransposeImage.transpose_image(observation)
|
||||
if (transpose_obs.shape == self.observation_space.shape
|
||||
|
|
@@ -160,21 +167,21 @@ class BasePolicy(nn.Module):
|
|||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale to proper domain when using squashing
|
||||
if isinstance(self.action_space, gym.spaces.Box) and self.squash_output:
|
||||
actions = self.unscale_action(actions)
|
||||
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error when using gaussian distribution
|
||||
if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output:
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
else:
|
||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
if not vectorized_env:
|
||||
if state is not None:
|
||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
||||
clipped_actions = clipped_actions[0]
|
||||
actions = actions[0]
|
||||
|
||||
return clipped_actions, state
|
||||
return actions, state
|
||||
|
||||
def scale_action(self, action: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
|
|
@@ -227,7 +234,7 @@ class BasePolicy(nn.Module):
|
|||
Load policy from path.
|
||||
|
||||
:param path: (str)
|
||||
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:return: (BasePolicy)
|
||||
"""
|
||||
device = get_device(device)
|
||||
|
|
@@ -294,7 +301,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
lr_schedule: Callable[[float], float],
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@@ -313,7 +320,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
# Small values to avoid NaN in ADAM optimizer
|
||||
# Small values to avoid NaN in Adam optimizer
|
||||
if optimizer_class == th.optim.Adam:
|
||||
optimizer_kwargs['eps'] = 1e-5
|
||||
|
||||
|
|
@@ -366,15 +373,17 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
|
||||
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
|
||||
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
|
||||
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
|
||||
squash_output=default_none_kwargs['squash_output'],
|
||||
full_std=default_none_kwargs['full_std'],
|
||||
sde_net_arch=default_none_kwargs['sde_net_arch'],
|
||||
use_expln=default_none_kwargs['use_expln'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
optimizer_class=self.optimizer_class,
|
||||
|
|
@@ -394,7 +403,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
def _build(self, lr_schedule: Callable[[float], float]) -> None:
|
||||
"""
|
||||
Create the networks and the optimizer.
|
||||
|
||||
|
|
@@ -719,10 +728,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
|
|||
:return: (Type[BasePolicy]) the policy
|
||||
"""
|
||||
if base_policy_type not in _policy_registry:
|
||||
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
if name not in _policy_registry[base_policy_type]:
|
||||
raise ValueError(f"Error: unknown policy type {name},"
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
raise KeyError(f"Error: unknown policy type {name},"
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
return _policy_registry[base_policy_type][name]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,5 @@
|
|||
"""
|
||||
Common aliases for type hint
|
||||
"""
|
||||
"""Common aliases for type hints"""
|
||||
|
||||
from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
|
|
|||
Loading…
Reference in a new issue