diff --git a/.gitignore b/.gitignore index 43b627c..9f54889 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ keys/ # Virtualenv /env +/venv *.sublime-project diff --git a/Makefile b/Makefile index e874ee7..05ce4d2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ SHELL=/bin/bash +LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py pytest: ./scripts/run_tests.sh @@ -9,9 +10,9 @@ type: lint: # stop the build if there are Python syntax errors or undefined names # see https://lintlyci.github.io/Flake8Rules/ - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. - flake8 . --count --exit-zero --statistics + flake8 ${LINT_PATHS} --count --exit-zero --statistics doc: cd docs && make html diff --git a/docs/conf.py b/docs/conf.py index 5d74970..78138f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # PyEnchant. try: - import sphinxcontrib.spelling + import sphinxcontrib.spelling # noqa: F401 enable_spell_check = True except ImportError: enable_spell_check = False @@ -129,6 +129,7 @@ html_logo = '_static/img/logo.png' def setup(app): app.add_stylesheet("css/baselines_theme.css") + # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 4453231..f66220e 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -1,5 +1,7 @@ +"""Abstract base classes for RL algorithms.""" + import time -from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable +from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable from abc import ABC, abstractmethod from collections import deque import pathlib @@ -23,12 +25,30 @@ from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise +def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]: + """If env is a string, make the environment; otherwise, return env. + + :param env: (Union[GymEnv, str, None]) The environment to learn from. + :param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env. + :param verbose: (int) logging verbosity + :return A Gym (vector) environment. + """ + if isinstance(env, str): + if verbose >= 1: + print(f"Creating environment from the given name '{env}'") + env = gym.make(env) + if monitor_wrapper: + env = Monitor(env, filename=None) + + return env + + class BaseAlgorithm(ABC): """ The base of RL algorithms :param policy: (Type[BasePolicy]) Policy object - :param env: (Union[GymEnv, str]) The environment to learn from + :param env: (Union[GymEnv, str, None]) The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param policy_base: (Type[BasePolicy]) The base policy used by this method :param learning_rate: (float or callable) learning rate for the optimizer, @@ -54,7 +74,7 @@ class BaseAlgorithm(ABC): def __init__(self, policy: Type[BasePolicy], - env: Union[GymEnv, str], + env: Union[GymEnv, str, None], policy_base: Type[BasePolicy], learning_rate: Union[float, Callable], policy_kwargs: Dict[str, Any] = None, @@ -116,18 +136,9 @@ class BaseAlgorithm(ABC): if env is not None: if isinstance(env, str): if create_eval_env: - eval_env = gym.make(env) - if monitor_wrapper: - eval_env = Monitor(eval_env, filename=None) - self.eval_env = DummyVecEnv([lambda: eval_env]) - if self.verbose >= 1: - print("Creating environment from the given name, wrapped in a DummyVecEnv.") - - env = gym.make(env) - if monitor_wrapper: - env = Monitor(env, filename=None) - env = DummyVecEnv([lambda: env]) + self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose) + env = maybe_make_env(env, monitor_wrapper, self.verbose) env = self._wrap_env(env) self.observation_space = env.observation_space @@ -136,8 +147,8 @@ class BaseAlgorithm(ABC): self.env = env if not support_multi_env and self.n_envs > 1: - raise ValueError("Error: the model does not support multiple envs requires a single vectorized" - " environment.") + raise ValueError("Error: the model does not support multiple envs; it requires " + "a single vectorized environment.") def _wrap_env(self, env: GymEnv) -> VecEnv: if not isinstance(env, VecEnv): @@ -153,10 +164,7 @@ class BaseAlgorithm(ABC): @abstractmethod def _setup_model(self) -> None: - """ - Create networks, buffer and optimizers - """ - raise NotImplementedError() + """Create networks, buffer and optimizers.""" def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]: """ @@ -238,7 +246,7 @@ class BaseAlgorithm(ABC): def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ - Get the name of the torch variable that will be saved. + Get the name of the torch variables that will be saved. ``th.save`` and ``th.load`` will be used with the right device instead of the default pickling strategy. @@ -263,10 +271,9 @@ class BaseAlgorithm(ABC): Return a trained model. :param total_timesteps: (int) The total number of samples (env steps) to train on - :param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm. - It takes the local and global variables. If it returns False, training is aborted. + :param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm. :param log_interval: (int) The number of timesteps before logging. - :param tb_log_name: (str) the name of the run for tensorboard log + :param tb_log_name: (str) the name of the run for TensorBoard logging :param eval_env: (gym.Env) Environment that will be used to evaluate the agent :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) :param n_eval_episodes: (int) Number of episode to evaluate the agent @@ -274,7 +281,6 @@ class BaseAlgorithm(ABC): :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) :return: (BaseAlgorithm) the trained model """ - raise NotImplementedError() def predict(self, observation: np.ndarray, state: Optional[np.ndarray] = None, @@ -329,8 +335,6 @@ class BaseAlgorithm(ABC): # load parameters model.__dict__.update(data) model.__dict__.update(kwargs) - if not hasattr(model, "_setup_model") and len(params) > 0: - raise NotImplementedError(f"{cls} has no ``_setup_model()`` method") model._setup_model() # put state_dicts back in place @@ -366,14 +370,18 @@ class BaseAlgorithm(ABC): self.eval_env.seed(seed) def _init_callback(self, - callback: Union[None, Callable, List[BaseCallback], BaseCallback], + callback: MaybeCallback, eval_env: Optional[VecEnv] = None, eval_freq: int = 10000, n_eval_episodes: int = 5, log_path: Optional[str] = None) -> BaseCallback: """ - :param callback: (Union[callable, [BaseCallback], BaseCallback, None]) - :return: (BaseCallback) + :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. + :param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate. + :param n_eval_episodes: (int) How many episodes to play per evaluation + :param n_eval_episodes: (int) Number of episodes to rollout during evaluation. + :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved + :return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation. """ # Convert a list of callbacks into a callback if isinstance(callback, list): @@ -396,7 +404,7 @@ class BaseAlgorithm(ABC): def _setup_learn(self, total_timesteps: int, eval_env: Optional[GymEnv], - callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, + callback: MaybeCallback = None, eval_freq: int = 10000, n_eval_episodes: int = 5, log_path: Optional[str] = None, @@ -407,11 +415,11 @@ class BaseAlgorithm(ABC): Initialize different variables needed for training. :param total_timesteps: (int) The total number of samples (env steps) to train on - :param eval_env: (Optional[GymEnv]) - :param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]]) + :param eval_env: (Optional[VecEnv]) Environment to use for evaluation. + :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. :param eval_freq: (int) How many steps between evaluations :param n_eval_episodes: (int) How many episodes to play per evaluation - :param log_path (Optional[str]): Path to a log folder + :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute :param tb_log_name: (str) the name of the run for tensorboard log :return: (Tuple[int, BaseCallback]) @@ -480,8 +488,8 @@ class BaseAlgorithm(ABC): def save( self, path: Union[str, pathlib.Path, io.BufferedIOBase], - exclude: Optional[List[str]] = None, - include: Optional[List[str]] = None, + exclude: Optional[Iterable[str]] = None, + include: Optional[Iterable[str]] = None, ) -> None: """ Save all the attributes of the object and the model parameters in a zip-file. @@ -492,16 +500,15 @@ class BaseAlgorithm(ABC): """ # copy parameter list so we don't mutate the original dict data = self.__dict__.copy() - # use standard list of excluded parameters if none given - if exclude is None: - exclude = self.excluded_save_params() - else: - # append standard exclude params to the given params - exclude.extend([param for param in self.excluded_save_params() if param not in exclude]) - # do not exclude params if they are specifically included + # Exclude is union of specified parameters (if any) and standard exclusions + if exclude is None: + exclude = [] + exclude = set(exclude).union(self.excluded_save_params()) + + # Do not exclude params if they are specifically included if include is not None: - exclude = [param_name for param_name in exclude if param_name not in include] + exclude = exclude.difference(include) state_dicts_names, tensors_names = self.get_torch_variables() # any params that are in the save vars must not be saved by data @@ -509,12 +516,11 @@ class BaseAlgorithm(ABC): for torch_var in torch_variables: # we need to get only the name of the top most module as we'll remove that var_name = torch_var.split('.')[0] - exclude.append(var_name) + exclude.add(var_name) # Remove parameter entries of parameters which are to be excluded for param_name in exclude: - if param_name in data: - data.pop(param_name, None) + data.pop(param_name, None) # Build dict of tensor variables tensors = None diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 951f163..6fa70f8 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,3 +1,6 @@ +"""Probability distributions.""" + +from abc import ABC, abstractmethod from typing import Optional, Tuple, Dict, Any, List import gym import torch as th @@ -8,36 +11,38 @@ from gym import spaces from stable_baselines3.common.preprocessing import get_action_dim -class Distribution(object): +class Distribution(ABC): + """Abstract base class for distributions.""" + def __init__(self): super(Distribution, self).__init__() + @abstractmethod def log_prob(self, x: th.Tensor) -> th.Tensor: """ - returns the log likelihood + Returns the log likelihood :param x: (th.Tensor) the taken action :return: (th.Tensor) The log likelihood of the distribution """ - raise NotImplementedError + @abstractmethod def entropy(self) -> Optional[th.Tensor]: """ Returns Shannon's entropy of the probability - :return: (Optional[th.Tensor]) the entropy, - return None if no analytical form is known + :return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known """ - raise NotImplementedError + @abstractmethod def sample(self) -> th.Tensor: """ Returns a sample from the probability distribution :return: (th.Tensor) the stochastic action """ - raise NotImplementedError + @abstractmethod def mode(self) -> th.Tensor: """ Returns the most likely action (deterministic output) @@ -45,7 +50,6 @@ class Distribution(object): :return: (th.Tensor) the stochastic action """ - raise NotImplementedError def get_actions(self, deterministic: bool = False) -> th.Tensor: """ @@ -58,6 +62,7 @@ class Distribution(object): return self.mode() return self.sample() + @abstractmethod def actions_from_params(self, *args, **kwargs) -> th.Tensor: """ Returns samples from the probability distribution @@ -65,8 +70,8 @@ class Distribution(object): :return: (th.Tensor) actions """ - raise NotImplementedError + @abstractmethod def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: """ Returns samples and the associated log probabilities @@ -74,14 +79,12 @@ class Distribution(object): :return: (th.Tuple[th.Tensor, th.Tensor]) actions and log prob """ - raise NotImplementedError def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: """ Continuous actions are usually considered to be independent, - so we can sum the components for the ``log_prob`` - or the entropy. + so we can sum components of the ``log_prob`` or the entropy. :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,) :return: (th.Tensor) shape: (n_batch,) @@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: class DiagGaussianDistribution(Distribution): """ - Gaussian distribution with diagonal covariance matrix, - for continuous actions. + Gaussian distribution with diagonal covariance matrix, for continuous actions. :param action_dim: (int) Dimension of the action space. """ @@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution): one output will be the mean of the Gaussian, the other parameter will be the standard deviation (log std in fact to allow negative values) - :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer) + :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer) :param log_std_init: (float) Initial value for the log standard deviation :return: (nn.Linear, nn.Parameter) """ @@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution): self.distribution = Normal(mean_actions, action_std) return self - def mode(self) -> th.Tensor: - return self.distribution.mean + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: (th.Tensor) + :return: (th.Tensor) + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) def sample(self) -> th.Tensor: # Reparametrization trick to pass gradients return self.distribution.rsample() - def entropy(self) -> th.Tensor: - return sum_independent_dims(self.distribution.entropy()) + def mode(self) -> th.Tensor: + return self.distribution.mean def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, @@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - """ - Get the log probabilities of actions according to the distribution. - Note that you must call ``proba_distribution()`` method before. - - :param actions: (th.Tensor) - :return: (th.Tensor) - """ - log_prob = self.distribution.log_prob(actions) - return sum_independent_dims(log_prob) - class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ - Gaussian distribution with diagonal covariance matrix, - followed by a squashing function (tanh) to ensure bounds. + Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. :param action_dim: (int) Dimension of the action space. :param epsilon: (float) small value to avoid NaN due to numerical imprecision. @@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) return self - def mode(self) -> th.Tensor: - self.gaussian_actions = self.distribution.mean - # Squash the output - return th.tanh(self.gaussian_actions) - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - return None - - def sample(self) -> th.Tensor: - # Reparametrization trick to pass gradients - self.gaussian_actions = self.distribution.rsample() - return th.tanh(self.gaussian_actions) - - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - action = self.actions_from_params(mean_actions, log_std) - log_prob = self.log_prob(action, self.gaussian_actions) - return action, log_prob - def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: # Inverse tanh @@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1) return log_prob + def entropy(self) -> Optional[th.Tensor]: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + self.gaussian_actions = super().sample() + return th.tanh(self.gaussian_actions) + + def mode(self) -> th.Tensor: + self.gaussian_actions = super().mode() + # Squash the output + return th.tanh(self.gaussian_actions) + + def log_prob_from_params(self, mean_actions: th.Tensor, + log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + action = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(action, self.gaussian_actions) + return action, log_prob + class CategoricalDistribution(Distribution): """ @@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution): self.distribution = Categorical(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.argmax(self.distribution.probs, dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy() def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy() + def mode(self) -> th.Tensor: + return th.argmax(self.distribution.probs, dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions) - class MultiCategoricalDistribution(Distribution): """ @@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution): self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self - def mode(self) -> th.Tensor: - return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + # Extract each discrete action and compute log prob for their respective distributions + return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, + th.unbind(actions, dim=1))], dim=1).sum(dim=1) + + def entropy(self) -> th.Tensor: + return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) def sample(self) -> th.Tensor: return th.stack([dist.sample() for dist in self.distributions], dim=1) - def entropy(self) -> th.Tensor: - return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + def mode(self) -> th.Tensor: + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - # Extract each discrete action and compute log prob for their respective distributions - return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, - th.unbind(actions, dim=1))], dim=1).sum(dim=1) - class BernoulliDistribution(Distribution): """ @@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution): self.distribution = Bernoulli(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.round(self.distribution.probs) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions).sum(dim=1) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy().sum(dim=1) def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy().sum(dim=1) + def mode(self) -> th.Tensor: + return th.round(self.distribution.probs) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions).sum(dim=1) - class StateDependentNoiseDistribution(Distribution): """ @@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution): a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: (bool) Whether to squash the output using a tanh function, - this allows to ensure boundaries. + this ensures bounds are satisfied. :param learn_features: (bool) Whether to learn features for gSDE or not. This will enable gradients to be backpropagated through the features ``latent_sde`` in the code. @@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution): self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) return self + def log_prob(self, actions: th.Tensor) -> th.Tensor: + if self.bijector is not None: + gaussian_actions = self.bijector.inverse(actions) + else: + gaussian_actions = actions + # log likelihood for a gaussian + log_prob = self.distribution.log_prob(gaussian_actions) + # Sum along action dim + log_prob = sum_independent_dims(log_prob) + + if self.bijector is not None: + # Squash correction (from original SAC implementation) + log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) + return log_prob + + def entropy(self) -> Optional[th.Tensor]: + if self.bijector is not None: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + noise = self.get_noise(self._latent_sde) + actions = self.distribution.mean + noise + if self.bijector is not None: + return self.bijector.forward(actions) + return actions + def mode(self) -> th.Tensor: actions = self.distribution.mean if self.bijector is not None: @@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution): noise = th.bmm(latent_sde, self.exploration_matrices) return noise.squeeze(1) - def sample(self) -> th.Tensor: - noise = self.get_noise(self._latent_sde) - actions = self.distribution.mean + noise - if self.bijector is not None: - return self.bijector.forward(actions) - return actions - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - if self.bijector is not None: - return None - return sum_independent_dims(self.distribution.entropy()) - def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, @@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - if self.bijector is not None: - gaussian_actions = self.bijector.inverse(actions) - else: - gaussian_actions = actions - # log likelihood for a gaussian - log_prob = self.distribution.log_prob(gaussian_actions) - # Sum along action dim - log_prob = sum_independent_dims(log_prob) - - if self.bijector is not None: - # Squash correction (from original SAC implementation) - log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) - return log_prob - class TanhBijector(object): """ @@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" - if use_sde: - return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) - return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) + cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution + return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index f459de9..b7d92ad 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.action_space, self.device, optimize_memory_usage=self.optimize_memory_usage) self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, **self.policy_kwargs) + self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 9090d78..2937b77 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): n_envs=self.n_envs) self.policy = self.policy_class(self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, device=self.device, - **self.policy_kwargs) + **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts(self, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index ba4797f..08b0ab2 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -1,3 +1,7 @@ +"""Policies: abstract base class and concrete implementations.""" + +from abc import ABC, abstractmethod +import collections from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable from functools import partial @@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis StateDependentNoiseDistribution) -class BasePolicy(nn.Module): +class BasePolicy(nn.Module, ABC): """ The base policy object @@ -98,17 +102,22 @@ class BasePolicy(nn.Module): module.bias.data.fill_(0.0) @staticmethod - def _dummy_schedule(_progress_remaining: float) -> float: + def _dummy_schedule(progress_remaining: float) -> float: """ (float) Useful for pickling policy.""" + del progress_remaining return 0.0 - def forward(self, *_args, **kwargs): - raise NotImplementedError() + @abstractmethod + def forward(self, *args, **kwargs): + del args, kwargs def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + :param observation: (th.Tensor) :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy @@ -140,10 +149,8 @@ class BasePolicy(nn.Module): # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if (observation.shape == self.observation_space.shape + if not (observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape): - pass - else: # Try to re-order the channels transpose_obs = VecTransposeImage.transpose_image(observation) if (transpose_obs.shape == self.observation_space.shape @@ -160,21 +167,21 @@ class BasePolicy(nn.Module): # Convert to numpy actions = actions.cpu().numpy() - # Rescale to proper domain when using squashing - if isinstance(self.action_space, gym.spaces.Box) and self.squash_output: - actions = self.unscale_action(actions) - - clipped_actions = actions - # Clip the actions to avoid out of bound error when using gaussian distribution - if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output: - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) if not vectorized_env: if state is not None: raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - clipped_actions = clipped_actions[0] + actions = actions[0] - return clipped_actions, state + return actions, state def scale_action(self, action: np.ndarray) -> np.ndarray: """ @@ -227,7 +234,7 @@ class BasePolicy(nn.Module): Load policy from path. :param path: (str) - :param device: ( Union[th.device, str]) Device on which the policy should be loaded. + :param device: (Union[th.device, str]) Device on which the policy should be loaded. :return: (BasePolicy) """ device = get_device(device) @@ -294,7 +301,7 @@ class ActorCriticPolicy(BasePolicy): def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: Callable, + lr_schedule: Callable[[float], float], net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, device: Union[th.device, str] = 'auto', activation_fn: Type[nn.Module] = nn.Tanh, @@ -313,7 +320,7 @@ class ActorCriticPolicy(BasePolicy): if optimizer_kwargs is None: optimizer_kwargs = {} - # Small values to avoid NaN in ADAM optimizer + # Small values to avoid NaN in Adam optimizer if optimizer_class == th.optim.Adam: optimizer_kwargs['eps'] = 1e-5 @@ -366,15 +373,17 @@ class ActorCriticPolicy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + data.update(dict( net_arch=self.net_arch, activation_fn=self.activation_fn, use_sde=self.use_sde, log_std_init=self.log_std_init, - squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None, - full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None, - sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, - use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, + squash_output=default_none_kwargs['squash_output'], + full_std=default_none_kwargs['full_std'], + sde_net_arch=default_none_kwargs['sde_net_arch'], + use_expln=default_none_kwargs['use_expln'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone ortho_init=self.ortho_init, optimizer_class=self.optimizer_class, @@ -394,7 +403,7 @@ class ActorCriticPolicy(BasePolicy): StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE' self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - def _build(self, lr_schedule: Callable) -> None: + def _build(self, lr_schedule: Callable[[float], float]) -> None: """ Create the networks and the optimizer. @@ -719,10 +728,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: - raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") + raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: - raise ValueError(f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") + raise KeyError(f"Error: unknown policy type {name}," + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") return _policy_registry[base_policy_type][name] diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 139cd60..5ef94f9 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,6 +1,5 @@ -""" -Common aliases for type hint -""" +"""Common aliases for type hints""" + from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple import numpy as np