From 2affbd6856360e613e4dbbb5673f87918b5a650f Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 18:14:54 -0700 Subject: [PATCH 01/13] Fix linting and make it play nicely with venv --- .gitignore | 1 + Makefile | 5 +++-- docs/conf.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 43b627c..9f54889 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ keys/ # Virtualenv /env +/venv *.sublime-project diff --git a/Makefile b/Makefile index e874ee7..05ce4d2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ SHELL=/bin/bash +LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py pytest: ./scripts/run_tests.sh @@ -9,9 +10,9 @@ type: lint: # stop the build if there are Python syntax errors or undefined names # see https://lintlyci.github.io/Flake8Rules/ - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + flake8 ${LINT_PATHS} --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. - flake8 . --count --exit-zero --statistics + flake8 ${LINT_PATHS} --count --exit-zero --statistics doc: cd docs && make html diff --git a/docs/conf.py b/docs/conf.py index 5d74970..78138f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support # PyEnchant. try: - import sphinxcontrib.spelling + import sphinxcontrib.spelling # noqa: F401 enable_spell_check = True except ImportError: enable_spell_check = False @@ -129,6 +129,7 @@ html_logo = '_static/img/logo.png' def setup(app): app.add_stylesheet("css/baselines_theme.css") + # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. From 1f0443f332821323403548f5994b2700eb943f23 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 18:49:59 -0700 Subject: [PATCH 02/13] Review base_class --- stable_baselines3/common/base_class.py | 100 +++++++++++++------------ 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 4453231..f66220e 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -1,5 +1,7 @@ +"""Abstract base classes for RL algorithms.""" + import time -from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable +from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable from abc import ABC, abstractmethod from collections import deque import pathlib @@ -23,12 +25,30 @@ from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise +def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]: + """If env is a string, make the environment; otherwise, return env. + + :param env: (Union[GymEnv, str, None]) The environment to learn from. + :param monitor_wrapper: (bool) Whether to wrap env in a Monitor when creating env. + :param verbose: (int) logging verbosity + :return A Gym (vector) environment. + """ + if isinstance(env, str): + if verbose >= 1: + print(f"Creating environment from the given name '{env}'") + env = gym.make(env) + if monitor_wrapper: + env = Monitor(env, filename=None) + + return env + + class BaseAlgorithm(ABC): """ The base of RL algorithms :param policy: (Type[BasePolicy]) Policy object - :param env: (Union[GymEnv, str]) The environment to learn from + :param env: (Union[GymEnv, str, None]) The environment to learn from (if registered in Gym, can be str. Can be None for loading trained models) :param policy_base: (Type[BasePolicy]) The base policy used by this method :param learning_rate: (float or callable) learning rate for the optimizer, @@ -54,7 +74,7 @@ class BaseAlgorithm(ABC): def __init__(self, policy: Type[BasePolicy], - env: Union[GymEnv, str], + env: Union[GymEnv, str, None], policy_base: Type[BasePolicy], learning_rate: Union[float, Callable], policy_kwargs: Dict[str, Any] = None, @@ -116,18 +136,9 @@ class BaseAlgorithm(ABC): if env is not None: if isinstance(env, str): if create_eval_env: - eval_env = gym.make(env) - if monitor_wrapper: - eval_env = Monitor(eval_env, filename=None) - self.eval_env = DummyVecEnv([lambda: eval_env]) - if self.verbose >= 1: - print("Creating environment from the given name, wrapped in a DummyVecEnv.") - - env = gym.make(env) - if monitor_wrapper: - env = Monitor(env, filename=None) - env = DummyVecEnv([lambda: env]) + self.eval_env = maybe_make_env(env, monitor_wrapper, self.verbose) + env = maybe_make_env(env, monitor_wrapper, self.verbose) env = self._wrap_env(env) self.observation_space = env.observation_space @@ -136,8 +147,8 @@ class BaseAlgorithm(ABC): self.env = env if not support_multi_env and self.n_envs > 1: - raise ValueError("Error: the model does not support multiple envs requires a single vectorized" - " environment.") + raise ValueError("Error: the model does not support multiple envs; it requires " + "a single vectorized environment.") def _wrap_env(self, env: GymEnv) -> VecEnv: if not isinstance(env, VecEnv): @@ -153,10 +164,7 @@ class BaseAlgorithm(ABC): @abstractmethod def _setup_model(self) -> None: - """ - Create networks, buffer and optimizers - """ - raise NotImplementedError() + """Create networks, buffer and optimizers.""" def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]: """ @@ -238,7 +246,7 @@ class BaseAlgorithm(ABC): def get_torch_variables(self) -> Tuple[List[str], List[str]]: """ - Get the name of the torch variable that will be saved. + Get the name of the torch variables that will be saved. ``th.save`` and ``th.load`` will be used with the right device instead of the default pickling strategy. @@ -263,10 +271,9 @@ class BaseAlgorithm(ABC): Return a trained model. :param total_timesteps: (int) The total number of samples (env steps) to train on - :param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm. - It takes the local and global variables. If it returns False, training is aborted. + :param callback: (MaybeCallback) callback(s) called at every step with state of the algorithm. :param log_interval: (int) The number of timesteps before logging. - :param tb_log_name: (str) the name of the run for tensorboard log + :param tb_log_name: (str) the name of the run for TensorBoard logging :param eval_env: (gym.Env) Environment that will be used to evaluate the agent :param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little) :param n_eval_episodes: (int) Number of episode to evaluate the agent @@ -274,7 +281,6 @@ class BaseAlgorithm(ABC): :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) :return: (BaseAlgorithm) the trained model """ - raise NotImplementedError() def predict(self, observation: np.ndarray, state: Optional[np.ndarray] = None, @@ -329,8 +335,6 @@ class BaseAlgorithm(ABC): # load parameters model.__dict__.update(data) model.__dict__.update(kwargs) - if not hasattr(model, "_setup_model") and len(params) > 0: - raise NotImplementedError(f"{cls} has no ``_setup_model()`` method") model._setup_model() # put state_dicts back in place @@ -366,14 +370,18 @@ class BaseAlgorithm(ABC): self.eval_env.seed(seed) def _init_callback(self, - callback: Union[None, Callable, List[BaseCallback], BaseCallback], + callback: MaybeCallback, eval_env: Optional[VecEnv] = None, eval_freq: int = 10000, n_eval_episodes: int = 5, log_path: Optional[str] = None) -> BaseCallback: """ - :param callback: (Union[callable, [BaseCallback], BaseCallback, None]) - :return: (BaseCallback) + :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. + :param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate. + :param n_eval_episodes: (int) How many episodes to play per evaluation + :param n_eval_episodes: (int) Number of episodes to rollout during evaluation. + :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved + :return: (BaseCallback) A hybrid callback calling `callback` and performing evaluation. """ # Convert a list of callbacks into a callback if isinstance(callback, list): @@ -396,7 +404,7 @@ class BaseAlgorithm(ABC): def _setup_learn(self, total_timesteps: int, eval_env: Optional[GymEnv], - callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, + callback: MaybeCallback = None, eval_freq: int = 10000, n_eval_episodes: int = 5, log_path: Optional[str] = None, @@ -407,11 +415,11 @@ class BaseAlgorithm(ABC): Initialize different variables needed for training. :param total_timesteps: (int) The total number of samples (env steps) to train on - :param eval_env: (Optional[GymEnv]) - :param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]]) + :param eval_env: (Optional[VecEnv]) Environment to use for evaluation. + :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. :param eval_freq: (int) How many steps between evaluations :param n_eval_episodes: (int) How many episodes to play per evaluation - :param log_path (Optional[str]): Path to a log folder + :param log_path: (Optional[str]) Path to a folder where the evaluations will be saved :param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute :param tb_log_name: (str) the name of the run for tensorboard log :return: (Tuple[int, BaseCallback]) @@ -480,8 +488,8 @@ class BaseAlgorithm(ABC): def save( self, path: Union[str, pathlib.Path, io.BufferedIOBase], - exclude: Optional[List[str]] = None, - include: Optional[List[str]] = None, + exclude: Optional[Iterable[str]] = None, + include: Optional[Iterable[str]] = None, ) -> None: """ Save all the attributes of the object and the model parameters in a zip-file. @@ -492,16 +500,15 @@ class BaseAlgorithm(ABC): """ # copy parameter list so we don't mutate the original dict data = self.__dict__.copy() - # use standard list of excluded parameters if none given - if exclude is None: - exclude = self.excluded_save_params() - else: - # append standard exclude params to the given params - exclude.extend([param for param in self.excluded_save_params() if param not in exclude]) - # do not exclude params if they are specifically included + # Exclude is union of specified parameters (if any) and standard exclusions + if exclude is None: + exclude = [] + exclude = set(exclude).union(self.excluded_save_params()) + + # Do not exclude params if they are specifically included if include is not None: - exclude = [param_name for param_name in exclude if param_name not in include] + exclude = exclude.difference(include) state_dicts_names, tensors_names = self.get_torch_variables() # any params that are in the save vars must not be saved by data @@ -509,12 +516,11 @@ class BaseAlgorithm(ABC): for torch_var in torch_variables: # we need to get only the name of the top most module as we'll remove that var_name = torch_var.split('.')[0] - exclude.append(var_name) + exclude.add(var_name) # Remove parameter entries of parameters which are to be excluded for param_name in exclude: - if param_name in data: - data.pop(param_name, None) + data.pop(param_name, None) # Build dict of tensor variables tensors = None From 56fd89da8d849f01358b522b9bbdfb6961bbcf53 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 18:51:10 -0700 Subject: [PATCH 03/13] Review type aliases --- stable_baselines3/common/type_aliases.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 139cd60..5ef94f9 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,6 +1,5 @@ -""" -Common aliases for type hint -""" +"""Common aliases for type hints""" + from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple import numpy as np From 7ba48dce4817e15fe3bb6727fb052eca990e2248 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 19:18:51 -0700 Subject: [PATCH 04/13] Review distributions --- stable_baselines3/common/distributions.py | 218 +++++++++++----------- 1 file changed, 109 insertions(+), 109 deletions(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 951f163..068bba5 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,3 +1,6 @@ +"""Probability distributions.""" + +from abc import ABC, abstractmethod from typing import Optional, Tuple, Dict, Any, List import gym import torch as th @@ -8,36 +11,38 @@ from gym import spaces from stable_baselines3.common.preprocessing import get_action_dim -class Distribution(object): +class Distribution(ABC): + """Abstract base class for distributions.""" + def __init__(self): super(Distribution, self).__init__() + @abstractmethod def log_prob(self, x: th.Tensor) -> th.Tensor: """ - returns the log likelihood + Returns the log likelihood :param x: (th.Tensor) the taken action :return: (th.Tensor) The log likelihood of the distribution """ - raise NotImplementedError + @abstractmethod def entropy(self) -> Optional[th.Tensor]: """ Returns Shannon's entropy of the probability - :return: (Optional[th.Tensor]) the entropy, - return None if no analytical form is known + :return: (Optional[th.Tensor]) the entropy, or None if no analytical form is known """ - raise NotImplementedError + @abstractmethod def sample(self) -> th.Tensor: """ Returns a sample from the probability distribution :return: (th.Tensor) the stochastic action """ - raise NotImplementedError + @abstractmethod def mode(self) -> th.Tensor: """ Returns the most likely action (deterministic output) @@ -45,7 +50,6 @@ class Distribution(object): :return: (th.Tensor) the stochastic action """ - raise NotImplementedError def get_actions(self, deterministic: bool = False) -> th.Tensor: """ @@ -58,6 +62,7 @@ class Distribution(object): return self.mode() return self.sample() + @abstractmethod def actions_from_params(self, *args, **kwargs) -> th.Tensor: """ Returns samples from the probability distribution @@ -65,8 +70,8 @@ class Distribution(object): :return: (th.Tensor) actions """ - raise NotImplementedError + @abstractmethod def log_prob_from_params(self, *args, **kwargs) -> Tuple[th.Tensor, th.Tensor]: """ Returns samples and the associated log probabilities @@ -74,14 +79,12 @@ class Distribution(object): :return: (th.Tuple[th.Tensor, th.Tensor]) actions and log prob """ - raise NotImplementedError def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: """ Continuous actions are usually considered to be independent, - so we can sum the components for the ``log_prob`` - or the entropy. + so we can sum components of the ``log_prob`` or the entropy. :param tensor: (th.Tensor) shape: (n_batch, n_actions) or (n_batch,) :return: (th.Tensor) shape: (n_batch,) @@ -95,8 +98,7 @@ def sum_independent_dims(tensor: th.Tensor) -> th.Tensor: class DiagGaussianDistribution(Distribution): """ - Gaussian distribution with diagonal covariance matrix, - for continuous actions. + Gaussian distribution with diagonal covariance matrix, for continuous actions. :param action_dim: (int) Dimension of the action space. """ @@ -115,7 +117,7 @@ class DiagGaussianDistribution(Distribution): one output will be the mean of the Gaussian, the other parameter will be the standard deviation (log std in fact to allow negative values) - :param latent_dim: (int) Dimension og the last layer of the policy (before the action layer) + :param latent_dim: (int) Dimension of the last layer of the policy (before the action layer) :param log_std_init: (float) Initial value for the log standard deviation :return: (nn.Linear, nn.Parameter) """ @@ -137,15 +139,26 @@ class DiagGaussianDistribution(Distribution): self.distribution = Normal(mean_actions, action_std) return self - def mode(self) -> th.Tensor: - return self.distribution.mean + def log_prob(self, actions: th.Tensor) -> th.Tensor: + """ + Get the log probabilities of actions according to the distribution. + Note that you must first call the ``proba_distribution()`` method. + + :param actions: (th.Tensor) + :return: (th.Tensor) + """ + log_prob = self.distribution.log_prob(actions) + return sum_independent_dims(log_prob) + + def entropy(self) -> th.Tensor: + return sum_independent_dims(self.distribution.entropy()) def sample(self) -> th.Tensor: # Reparametrization trick to pass gradients return self.distribution.rsample() - def entropy(self) -> th.Tensor: - return sum_independent_dims(self.distribution.entropy()) + def mode(self) -> th.Tensor: + return self.distribution.mean def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, @@ -168,22 +181,10 @@ class DiagGaussianDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - """ - Get the log probabilities of actions according to the distribution. - Note that you must call ``proba_distribution()`` method before. - - :param actions: (th.Tensor) - :return: (th.Tensor) - """ - log_prob = self.distribution.log_prob(actions) - return sum_independent_dims(log_prob) - class SquashedDiagGaussianDistribution(DiagGaussianDistribution): """ - Gaussian distribution with diagonal covariance matrix, - followed by a squashing function (tanh) to ensure bounds. + Gaussian distribution with diagonal covariance matrix, followed by a squashing function (tanh) to ensure bounds. :param action_dim: (int) Dimension of the action space. :param epsilon: (float) small value to avoid NaN due to numerical imprecision. @@ -200,27 +201,6 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) return self - def mode(self) -> th.Tensor: - self.gaussian_actions = self.distribution.mean - # Squash the output - return th.tanh(self.gaussian_actions) - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - return None - - def sample(self) -> th.Tensor: - # Reparametrization trick to pass gradients - self.gaussian_actions = self.distribution.rsample() - return th.tanh(self.gaussian_actions) - - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - action = self.actions_from_params(mean_actions, log_std) - log_prob = self.log_prob(action, self.gaussian_actions) - return action, log_prob - def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: # Inverse tanh @@ -237,6 +217,27 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution): log_prob -= th.sum(th.log(1 - actions ** 2 + self.epsilon), dim=1) return log_prob + def entropy(self) -> Optional[th.Tensor]: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + + def sample(self) -> th.Tensor: + # Reparametrization trick to pass gradients + self.gaussian_actions = super().sample() + return th.tanh(self.gaussian_actions) + + def mode(self) -> th.Tensor: + self.gaussian_actions = super().mode() + # Squash the output + return th.tanh(self.gaussian_actions) + + def log_prob_from_params(self, mean_actions: th.Tensor, + log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + action = self.actions_from_params(mean_actions, log_std) + log_prob = self.log_prob(action, self.gaussian_actions) + return action, log_prob + class CategoricalDistribution(Distribution): """ @@ -267,14 +268,17 @@ class CategoricalDistribution(Distribution): self.distribution = Categorical(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.argmax(self.distribution.probs, dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy() def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy() + def mode(self) -> th.Tensor: + return th.argmax(self.distribution.probs, dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -287,9 +291,6 @@ class CategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions) - class MultiCategoricalDistribution(Distribution): """ @@ -321,14 +322,19 @@ class MultiCategoricalDistribution(Distribution): self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self - def mode(self) -> th.Tensor: - return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + # Extract each discrete action and compute log prob for their respective distributions + return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, + th.unbind(actions, dim=1))], dim=1).sum(dim=1) + + def entropy(self) -> th.Tensor: + return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) def sample(self) -> th.Tensor: return th.stack([dist.sample() for dist in self.distributions], dim=1) - def entropy(self) -> th.Tensor: - return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) + def mode(self) -> th.Tensor: + return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -341,11 +347,6 @@ class MultiCategoricalDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - # Extract each discrete action and compute log prob for their respective distributions - return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, - th.unbind(actions, dim=1))], dim=1).sum(dim=1) - class BernoulliDistribution(Distribution): """ @@ -375,14 +376,17 @@ class BernoulliDistribution(Distribution): self.distribution = Bernoulli(logits=action_logits) return self - def mode(self) -> th.Tensor: - return th.round(self.distribution.probs) + def log_prob(self, actions: th.Tensor) -> th.Tensor: + return self.distribution.log_prob(actions).sum(dim=1) + + def entropy(self) -> th.Tensor: + return self.distribution.entropy().sum(dim=1) def sample(self) -> th.Tensor: return self.distribution.sample() - def entropy(self) -> th.Tensor: - return self.distribution.entropy().sum(dim=1) + def mode(self) -> th.Tensor: + return th.round(self.distribution.probs) def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: @@ -395,9 +399,6 @@ class BernoulliDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - return self.distribution.log_prob(actions).sum(dim=1) - class StateDependentNoiseDistribution(Distribution): """ @@ -414,7 +415,7 @@ class StateDependentNoiseDistribution(Distribution): a positive standard deviation (cf paper). It allows to keep variance above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. :param squash_output: (bool) Whether to squash the output using a tanh function, - this allows to ensure boundaries. + this ensures bounds are satisfied. :param learn_features: (bool) Whether to learn features for gSDE or not. This will enable gradients to be backpropagated through the features ``latent_sde`` in the code. @@ -529,6 +530,35 @@ class StateDependentNoiseDistribution(Distribution): self.distribution = Normal(mean_actions, th.sqrt(variance + self.epsilon)) return self + def log_prob(self, actions: th.Tensor) -> th.Tensor: + if self.bijector is not None: + gaussian_actions = self.bijector.inverse(actions) + else: + gaussian_actions = actions + # log likelihood for a gaussian + log_prob = self.distribution.log_prob(gaussian_actions) + # Sum along action dim + log_prob = sum_independent_dims(log_prob) + + if self.bijector is not None: + # Squash correction (from original SAC implementation) + log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) + return log_prob + + def entropy(self) -> Optional[th.Tensor]: + if self.bijector is not None: + # No analytical form, + # entropy needs to be estimated using -log_prob.mean() + return None + return sum_independent_dims(self.distribution.entropy()) + + def sample(self) -> th.Tensor: + noise = self.get_noise(self._latent_sde) + actions = self.distribution.mean + noise + if self.bijector is not None: + return self.bijector.forward(actions) + return actions + def mode(self) -> th.Tensor: actions = self.distribution.mean if self.bijector is not None: @@ -547,20 +577,6 @@ class StateDependentNoiseDistribution(Distribution): noise = th.bmm(latent_sde, self.exploration_matrices) return noise.squeeze(1) - def sample(self) -> th.Tensor: - noise = self.get_noise(self._latent_sde) - actions = self.distribution.mean + noise - if self.bijector is not None: - return self.bijector.forward(actions) - return actions - - def entropy(self) -> Optional[th.Tensor]: - # No analytical form, - # entropy needs to be estimated using -log_prob.mean() - if self.bijector is not None: - return None - return sum_independent_dims(self.distribution.entropy()) - def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, @@ -576,21 +592,6 @@ class StateDependentNoiseDistribution(Distribution): log_prob = self.log_prob(actions) return actions, log_prob - def log_prob(self, actions: th.Tensor) -> th.Tensor: - if self.bijector is not None: - gaussian_actions = self.bijector.inverse(actions) - else: - gaussian_actions = actions - # log likelihood for a gaussian - log_prob = self.distribution.log_prob(gaussian_actions) - # Sum along action dim - log_prob = sum_independent_dims(log_prob) - - if self.bijector is not None: - # Squash correction (from original SAC implementation) - log_prob -= th.sum(self.bijector.log_prob_correction(gaussian_actions), dim=1) - return log_prob - class TanhBijector(object): """ @@ -653,9 +654,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" - if use_sde: - return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs) - return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs) + cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution + cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): From cc7a58bc5f3ad3e4e610a54711afa9246324adee Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 20:28:48 -0700 Subject: [PATCH 05/13] Bugfix --- stable_baselines3/common/distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 068bba5..6fa70f8 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -655,7 +655,7 @@ def make_proba_distribution(action_space: gym.spaces.Space, if isinstance(action_space, spaces.Box): assert len(action_space.shape) == 1, "Error: the action space must be a vector" cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution - cls(get_action_dim(action_space), **dist_kwargs) + return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): return CategoricalDistribution(action_space.n, **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): From e9d8e05cc8e2ac09c2f011ef37589bd405f817d9 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 21:04:36 -0700 Subject: [PATCH 06/13] Review policies --- stable_baselines3/common/policies.py | 64 +++++++++++++++------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 6d4c87f..48e066c 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -1,3 +1,7 @@ +"""Policies: abstract base class and concrete implementations.""" + +from abc import ABC, abstractmethod +import collections from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable from functools import partial @@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis StateDependentNoiseDistribution) -class BasePolicy(nn.Module): +class BasePolicy(nn.Module, ABC): """ The base policy object @@ -98,13 +102,16 @@ class BasePolicy(nn.Module): module.bias.data.fill_(0.0) @staticmethod - def _dummy_schedule(_progress_remaining: float) -> float: + def _dummy_schedule(progress_remaining: float) -> float: """ (float) Useful for pickling policy.""" + del progress_remaining return 0.0 - def forward(self, *_args, **kwargs): - raise NotImplementedError() + @abstractmethod + def forward(self, *args, **kwargs): + del args, kwargs + @abstractmethod def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -113,7 +120,6 @@ class BasePolicy(nn.Module): :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ - raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -140,10 +146,8 @@ class BasePolicy(nn.Module): # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if (observation.shape == self.observation_space.shape + if not (observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape): - pass - else: # Try to re-order the channels transpose_obs = VecTransposeImage.transpose_image(observation) if (transpose_obs.shape == self.observation_space.shape @@ -160,21 +164,21 @@ class BasePolicy(nn.Module): # Convert to numpy actions = actions.cpu().numpy() - # Rescale to proper domain when using squashing - if isinstance(self.action_space, gym.spaces.Box) and self.squash_output: - actions = self.unscale_action(actions) - - clipped_actions = actions - # Clip the actions to avoid out of bound error when using gaussian distribution - if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output: - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) if not vectorized_env: if state is not None: raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - clipped_actions = clipped_actions[0] + actions = actions[0] - return clipped_actions, state + return actions, state def scale_action(self, action: np.ndarray) -> np.ndarray: """ @@ -227,7 +231,7 @@ class BasePolicy(nn.Module): Load policy from path. :param path: (str) - :param device: ( Union[th.device, str]) Device on which the policy should be loaded. + :param device: (Union[th.device, str]) Device on which the policy should be loaded. :return: (BasePolicy) """ device = get_device(device) @@ -294,7 +298,7 @@ class ActorCriticPolicy(BasePolicy): def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: Callable, + lr_schedule: Callable[[float], float], net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, device: Union[th.device, str] = 'auto', activation_fn: Type[nn.Module] = nn.Tanh, @@ -313,7 +317,7 @@ class ActorCriticPolicy(BasePolicy): if optimizer_kwargs is None: optimizer_kwargs = {} - # Small values to avoid NaN in ADAM optimizer + # Small values to avoid NaN in Adam optimizer if optimizer_class == th.optim.Adam: optimizer_kwargs['eps'] = 1e-5 @@ -366,15 +370,17 @@ class ActorCriticPolicy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() + default_none_kwargs = self.dist_kwargs or collections.defaultdict() + data.update(dict( net_arch=self.net_arch, activation_fn=self.activation_fn, use_sde=self.use_sde, log_std_init=self.log_std_init, - squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None, - full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None, - sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, - use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, + squash_output=default_none_kwargs['squash_output'], + full_std=default_none_kwargs['full_std'], + sde_net_arch=default_none_kwargs['sde_net_arch'], + use_expln=default_none_kwargs['use_expln'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone ortho_init=self.ortho_init, optimizer_class=self.optimizer_class, @@ -394,7 +400,7 @@ class ActorCriticPolicy(BasePolicy): StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE' self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - def _build(self, lr_schedule: Callable) -> None: + def _build(self, lr_schedule: Callable[[float], float]) -> None: """ Create the networks and the optimizer. @@ -651,10 +657,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: - raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") + raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: - raise ValueError(f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") + raise KeyError(f"Error: unknown policy type {name}," + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") return _policy_registry[base_policy_type][name] From e61d34a6f0db4eb92d6121d28fd3037c2da64dd1 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 21:35:06 -0700 Subject: [PATCH 07/13] Fix typing, key error --- stable_baselines3/common/off_policy_algorithm.py | 2 +- stable_baselines3/common/on_policy_algorithm.py | 2 +- stable_baselines3/common/policies.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index f459de9..b7d92ad 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -135,7 +135,7 @@ class OffPolicyAlgorithm(BaseAlgorithm): self.action_space, self.device, optimize_memory_usage=self.optimize_memory_usage) self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, **self.policy_kwargs) + self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 9090d78..2937b77 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -96,7 +96,7 @@ class OnPolicyAlgorithm(BaseAlgorithm): n_envs=self.n_envs) self.policy = self.policy_class(self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, device=self.device, - **self.policy_kwargs) + **self.policy_kwargs) # pytype:disable=not-instantiable self.policy = self.policy.to(self.device) def collect_rollouts(self, diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 48e066c..2aeffd9 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -111,15 +111,18 @@ class BasePolicy(nn.Module, ABC): def forward(self, *args, **kwargs): del args, kwargs - @abstractmethod def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + :param observation: (th.Tensor) :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ + raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -370,7 +373,7 @@ class ActorCriticPolicy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - default_none_kwargs = self.dist_kwargs or collections.defaultdict() + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) data.update(dict( net_arch=self.net_arch, From 3756d05f7265b8023b1987d9d94a25ec15dbcc1f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 7 Jul 2020 00:02:51 +0200 Subject: [PATCH 08/13] Refactored ContinuousCritic for SAC/TD3 (#78) * Refactored ContinuousCritic for SAC/TD3 * Address comments * Add pybullet notebook --- docs/guide/examples.rst | 8 ++- docs/misc/changelog.rst | 7 ++- stable_baselines3/common/policies.py | 70 +++++++++++++++++++++++++- stable_baselines3/sac/policies.py | 69 ++++++-------------------- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/policies.py | 73 ++++++---------------------- stable_baselines3/td3/td3.py | 1 + stable_baselines3/version.txt | 2 +- 8 files changed, 114 insertions(+), 118 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 0e3242b..a6b8040 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -17,6 +17,7 @@ notebooks: - `Monitor Training and Plotting`_ - `Atari Games`_ - `RL Baselines zoo`_ +- `PyBullet`_ .. - `Hindsight Experience Replay`_ @@ -27,6 +28,7 @@ notebooks: .. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb .. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb .. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb +.. _PyBullet: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb .. |colab| image:: ../_static/img/colab.svg @@ -291,7 +293,7 @@ PyBullet: Normalizing input features Normalizing input features may be essential to successful training of an RL agent (by default, images are scaled but not other types of input), -for instance when training on `PyBullet `_ environments. For that, a wrapper exists and +for instance when training on `PyBullet `__ environments. For that, a wrapper exists and will compute a running average and standard deviation of input features (it can do the same for rewards). @@ -300,6 +302,10 @@ will compute a running average and standard deviation of input features (it can you need to install pybullet with ``pip install pybullet`` +.. image:: ../_static/img/colab-badge.svg + :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/pybullet.ipynb + + .. code-block:: python import gym diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 45d78d3..0ca6fa9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,13 +3,15 @@ Changelog ========== -Pre-Release 0.8.0a2 (WIP) +Pre-Release 0.8.0a3 (WIP) ------------------------------ Breaking Changes: ^^^^^^^^^^^^^^^^^ - ``AtariWrapper`` and other Atari wrappers were updated to match SB2 ones - ``save_replay_buffer`` now receives as argument the file path instead of the folder path (@tirafesi) +- Refactored ``Critic`` class for ``TD3`` and ``SAC``, it is now called ``ContinuousCritic`` + and has an additional parameter ``n_critics`` New Features: ^^^^^^^^^^^^^ @@ -40,6 +42,7 @@ Documentation: - Updated notebook links - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake) - Added Unity reacher to the projects page (@koulakis) +- Added PyBullet colab notebook @@ -342,4 +345,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis +@tirafesi @blurLake @koulakis diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 6d4c87f..ba4797f 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -6,7 +6,7 @@ import torch as th import torch.nn as nn import numpy as np -from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space +from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp, NatureCNN, MlpExtractor) from stable_baselines3.common.utils import get_device, is_vectorized_observation @@ -617,6 +617,74 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): optimizer_kwargs) +class ContinuousCritic(BasePolicy): + """ + Critic network(s) for DDPG/SAC/TD3. + It represents the action-state value function (Q-value function). + Compared to A2C/PPO critics, this one represents the Q-value + and takes the continuous action as input. It is concatenated with the state + and then fed to the network which outputs a single value: Q(s, a). + For more recent algorithms like SAC/TD3, multiple networks + are created to give different estimates. + + By default, it creates two critic networks used to reduce overestimation + thanks to clipped Q-learning (cf TD3 paper). + + :param observation_space: (gym.spaces.Space) Obervation space + :param action_space: (gym.spaces.Space) Action space + :param net_arch: ([int]) Network architecture + :param features_extractor: (nn.Module) Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: (int) Number of features + :param activation_fn: (Type[nn.Module]) Activation function + :param normalize_images: (bool) Whether to normalize images or not, + dividing by 255.0 (True by default) + :param device: (Union[th.device, str]) Device on which the code should run. + :param n_critics: (int) Number of critic networks to create. + """ + + def __init__(self, observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + device: Union[th.device, str] = 'auto', + n_critics: int = 2): + super().__init__(observation_space, action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + device=device) + + action_dim = get_action_dim(self.action_space) + + self.n_critics = n_critics + self.q_networks = [] + for idx in range(n_critics): + q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = nn.Sequential(*q_net) + self.add_module(f'qf{idx}', q_net) + self.q_networks.append(q_net) + + def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + # Learn the features extractor using the policy loss only + with th.no_grad(): + features = self.extract_features(obs) + qvalue_input = th.cat([features, actions], dim=1) + return tuple(q_net(qvalue_input) for q_net in self.q_networks) + + def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: + """ + Only predict the Q-value using the first network. + This allows to reduce computation when all the estimates are not needed + (e.g. when updating the policy in TD3). + """ + with th.no_grad(): + features = self.extract_features(obs) + return self.q_networks[0](th.cat([features, actions], dim=1)) + + def create_sde_features_extractor(features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]: diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 335bbd4..e409293 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -5,7 +5,7 @@ import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor +from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -179,54 +179,6 @@ class Actor(BasePolicy): return self.forward(observation, deterministic) -class Critic(BasePolicy): - """ - Critic network (q-value function) for SAC. - - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, - dividing by 255.0 (True by default) - :param device: (Union[th.device, str]) Device on which the code should run. - """ - - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Critic, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device) - - action_dim = get_action_dim(self.action_space) - - q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q1_net = nn.Sequential(*q1_net) - - q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q2_net = nn.Sequential(*q2_net) - - self.q_networks = [self.q1_net, self.q2_net] - - def forward(self, obs: th.Tensor, action: th.Tensor) -> List[th.Tensor]: - # Learn the features extractor using the policy loss only - # this is much faster - with th.no_grad(): - features = self.extract_features(obs) - qvalue_input = th.cat([features, action], dim=1) - return [q_net(qvalue_input) for q_net in self.q_networks] - - class SACPolicy(BasePolicy): """ Policy class (with both actor and critic) for SAC. @@ -255,6 +207,7 @@ class SACPolicy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -272,7 +225,8 @@ class SACPolicy(BasePolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(SACPolicy, self).__init__(observation_space, action_space, device, features_extractor_class, @@ -313,6 +267,9 @@ class SACPolicy(BasePolicy): 'clip_mean': clip_mean } self.actor_kwargs.update(sde_kwargs) + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update({'n_critics': n_critics}) + self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -345,6 +302,7 @@ class SACPolicy(BasePolicy): sde_net_arch=self.actor_kwargs['sde_net_arch'], use_expln=self.actor_kwargs['use_expln'], clip_mean=self.actor_kwargs['clip_mean'], + n_critics=self.critic_kwargs['n_critics'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, @@ -364,8 +322,8 @@ class SACPolicy(BasePolicy): def make_actor(self) -> Actor: return Actor(**self.actor_kwargs).to(self.device) - def make_critic(self) -> Critic: - return Critic(**self.net_args).to(self.device) + def make_critic(self) -> ContinuousCritic: + return ContinuousCritic(**self.critic_kwargs).to(self.device) def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: return self._predict(obs, deterministic=deterministic) @@ -403,6 +361,7 @@ class CnnPolicy(SACPolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -420,7 +379,8 @@ class CnnPolicy(SACPolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(CnnPolicy, self).__init__(observation_space, action_space, lr_schedule, @@ -436,7 +396,8 @@ class CnnPolicy(SACPolicy): features_extractor_kwargs, normalize_images, optimizer_class, - optimizer_kwargs) + optimizer_kwargs, + n_critics) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 203abc4..04e20fa 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -118,7 +118,7 @@ class SAC(OffPolicyAlgorithm): def _setup_model(self) -> None: super(SAC, self)._setup_model() self._create_aliases() - + assert self.critic.n_critics == 2, "SAC only supports `n_critics=2` for now" # Target entropy is used when learning the entropy coefficient if self.target_entropy == 'auto': # automatically set target entropy if needed diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 7b863ad..325640f 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,11 +1,11 @@ -from typing import Optional, List, Tuple, Callable, Union, Type, Any, Dict +from typing import Optional, List, Callable, Union, Type, Any, Dict import gym import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor @@ -71,57 +71,6 @@ class Actor(BasePolicy): return self.forward(observation, deterministic=deterministic) -class Critic(BasePolicy): - """ - Critic network for TD3, - in fact it represents the action-state value function (Q-value function) - - :param observation_space: (gym.spaces.Space) Obervation space - :param action_space: (gym.spaces.Space) Action space - :param net_arch: ([int]) Network architecture - :param features_extractor: (nn.Module) Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: (int) Number of features - :param activation_fn: (Type[nn.Module]) Activation function - :param normalize_images: (bool) Whether to normalize images or not, - dividing by 255.0 (True by default) - :param device: (Union[th.device, str]) Device on which the code should run. - """ - - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Critic, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device) - - action_dim = get_action_dim(self.action_space) - - q1_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q1_net = nn.Sequential(*q1_net) - - q2_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - self.q2_net = nn.Sequential(*q2_net) - - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: - # Learn the features extractor using the policy loss only - with th.no_grad(): - features = self.extract_features(obs) - qvalue_input = th.cat([features, actions], dim=1) - return self.q1_net(qvalue_input), self.q2_net(qvalue_input) - - def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - with th.no_grad(): - features = self.extract_features(obs) - return self.q1_net(th.cat([features, actions], dim=1)) - - class TD3Policy(BasePolicy): """ Policy class (with both actor and critic) for TD3. @@ -141,6 +90,7 @@ class TD3Policy(BasePolicy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -153,7 +103,8 @@ class TD3Policy(BasePolicy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(TD3Policy, self).__init__(observation_space, action_space, device, features_extractor_class, @@ -185,6 +136,8 @@ class TD3Policy(BasePolicy): 'normalize_images': normalize_images, 'device': device } + self.critic_kwargs = self.net_args.copy() + self.critic_kwargs.update({'n_critics': n_critics}) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -208,6 +161,7 @@ class TD3Policy(BasePolicy): data.update(dict( net_arch=self.net_args['net_arch'], activation_fn=self.net_args['activation_fn'], + n_critics=self.critic_kwargs['n_critics'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone optimizer_class=self.optimizer_class, optimizer_kwargs=self.optimizer_kwargs, @@ -219,8 +173,8 @@ class TD3Policy(BasePolicy): def make_actor(self) -> Actor: return Actor(**self.net_args).to(self.device) - def make_critic(self) -> Critic: - return Critic(**self.net_args).to(self.device) + def make_critic(self) -> ContinuousCritic: + return ContinuousCritic(**self.critic_kwargs).to(self.device) def forward(self, observation: th.Tensor, deterministic: bool = False): return self._predict(observation, deterministic=deterministic) @@ -251,6 +205,7 @@ class CnnPolicy(TD3Policy): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer + :param n_critics: (int) Number of critic networks to create. """ def __init__(self, observation_space: gym.spaces.Space, @@ -263,7 +218,8 @@ class CnnPolicy(TD3Policy): features_extractor_kwargs: Optional[Dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2): super(CnnPolicy, self).__init__(observation_space, action_space, lr_schedule, @@ -274,7 +230,8 @@ class CnnPolicy(TD3Policy): features_extractor_kwargs, normalize_images, optimizer_class, - optimizer_kwargs) + optimizer_kwargs, + n_critics) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 459e240..13bcc98 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -96,6 +96,7 @@ class TD3(OffPolicyAlgorithm): def _setup_model(self) -> None: super(TD3, self)._setup_model() self._create_aliases() + assert self.critic.n_critics == 2, "TD3 only supports `n_critics=2` for now" def _create_aliases(self) -> None: self.actor = self.policy.actor diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 8db4718..8369211 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.8.0a2 +0.8.0a3 From 91bbc28c0f01b8a50cea61c113186c5ab8ac3068 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 7 Jul 2020 18:39:55 -0700 Subject: [PATCH 09/13] Address minor issues after clarification by @araffin --- stable_baselines3/common/base_class.py | 10 ++++++---- stable_baselines3/common/distributions.py | 14 ++++++++++++++ stable_baselines3/common/policies.py | 3 +++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index f66220e..4e51367 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -320,8 +320,8 @@ class BaseAlgorithm(ABC): f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}") # check if observation space and action space are part of the saved parameters - if ("observation_space" not in data or "action_space" not in data) and "env" not in data: - raise ValueError("The observation_space and action_space was not given, can't verify new environments") + if "observation_space" not in data or "action_space" not in data: + raise KeyError("The observation_space and action_space were not given, can't verify new environments") # check if given env is valid if env is not None: check_for_correct_spaces(env, data["observation_space"], data["action_space"]) @@ -425,8 +425,10 @@ class BaseAlgorithm(ABC): :return: (Tuple[int, BaseCallback]) """ self.start_time = time.time() - self.ep_info_buffer = deque(maxlen=100) - self.ep_success_buffer = deque(maxlen=100) + if self.ep_info_buffer is None or reset_num_timesteps: + # Initialize buffers if they don't exist, or reinitialize if resetting counters + self.ep_info_buffer = deque(maxlen=100) + self.ep_success_buffer = deque(maxlen=100) if self.action_noise is not None: self.action_noise.reset() diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 6fa70f8..1a20928 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -17,6 +17,20 @@ class Distribution(ABC): def __init__(self): super(Distribution, self).__init__() + @abstractmethod + def proba_distribution_net(self, *args, **kwargs): + """Create the layers and parameters that represent the distribution. + + Subclasses must define this, but the arguments and return type vary between + concrete classes.""" + + @abstractmethod + def proba_distribution(self, *args, **kwargs) -> 'Distribution': + """Set parameters of the distribution. + + :return: (Distribution) self + """ + @abstractmethod def log_prob(self, x: th.Tensor) -> th.Tensor: """ diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 2aeffd9..7342c46 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -140,6 +140,7 @@ class BasePolicy(nn.Module, ABC): :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state (used in recurrent policies) """ + # TODO (GH/1): add support for RNN policies # if state is None: # state = self.initial_state # if mask is None: @@ -438,6 +439,8 @@ class ActorCriticPolicy(BasePolicy): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) elif isinstance(self.action_dist, BernoulliDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) # Init weights: use orthogonal initialization From 0345591deaa53e4240f4534fffe05f88e109cbb0 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 7 Jul 2020 18:51:44 -0700 Subject: [PATCH 10/13] Refactor BasePolicy by introducing new BaseModel ABC for Critic's to inherit from. --- stable_baselines3/common/policies.py | 169 +++++++++++++++------------ stable_baselines3/sac/policies.py | 4 +- stable_baselines3/td3/policies.py | 4 +- 3 files changed, 96 insertions(+), 81 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 7342c46..16ba652 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -21,9 +21,12 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis StateDependentNoiseDistribution) -class BasePolicy(nn.Module, ABC): +class BaseModel(nn.Module, ABC): """ - The base policy object + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. In the case of critics, it is the + estimated value of the observation. :param observation_space: (gym.spaces.Space) The observation space of the environment :param action_space: (gym.spaces.Space) The action space of the environment @@ -39,8 +42,6 @@ class BasePolicy(nn.Module, ABC): ``th.optim.Adam`` by default :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer - :param squash_output: (bool) For continuous actions, whether the output is squashed - or not using a ``tanh()`` function. """ def __init__(self, @@ -52,9 +53,8 @@ class BasePolicy(nn.Module, ABC): features_extractor: Optional[nn.Module] = None, normalize_images: bool = True, optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - squash_output: bool = False): - super(BasePolicy, self).__init__() + optimizer_kwargs: Optional[Dict[str, Any]] = None): + super(BaseModel, self).__init__() if optimizer_kwargs is None: optimizer_kwargs = {} @@ -67,7 +67,6 @@ class BasePolicy(nn.Module, ABC): self.device = get_device(device) self.features_extractor = features_extractor self.normalize_images = normalize_images - self._squash_output = squash_output self.optimizer_class = optimizer_class self.optimizer_kwargs = optimizer_kwargs @@ -76,6 +75,10 @@ class BasePolicy(nn.Module, ABC): self.features_extractor_class = features_extractor_class self.features_extractor_kwargs = features_extractor_kwargs + @abstractmethod + def forward(self, *args, **kwargs): + del args, kwargs + def extract_features(self, obs: th.Tensor) -> th.Tensor: """ Preprocess the observation if needed and extract features. @@ -87,9 +90,89 @@ class BasePolicy(nn.Module, ABC): preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) + def _get_data(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model. + This corresponds to the arguments of the constructor. + + :return: (Dict[str, Any]) + """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: (str) + """ + th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) + + @classmethod + def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel': + """ + Load model from path. + + :param path: (str) + :param device: (Union[th.device, str]) Device on which the policy should be loaded. + :return: (BasePolicy) + """ + device = get_device(device) + saved_variables = th.load(path, map_location=device) + # Create policy object + model = cls(**saved_variables['data']) + # Load weights + model.load_state_dict(saved_variables['state_dict']) + model.to(device) + return model + + def load_from_vector(self, vector: np.ndarray): + """ + Load parameters from a 1D vector. + + :param vector: (np.ndarray) + """ + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: (np.ndarray) + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + +class BasePolicy(BaseModel): + """The base policy object. + + Parameters are mostly the same as `BaseModel`; additions are documented below. + + :param args: positional arguments passed through to `BaseModel`. + :param kwargs: keyword arguments passed through to `BaseModel`. + :param squash_output: (bool) For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. + """ + def __init__(self, *args, squash_output: bool = False, **kwargs): + super(BasePolicy, self).__init__(*args, **kwargs) + self._squash_output = squash_output + + + @staticmethod + def _dummy_schedule(progress_remaining: float) -> float: + """ (float) Useful for pickling policy.""" + del progress_remaining + return 0.0 + @property def squash_output(self) -> bool: - """ (bool) Getter for squash_output.""" + """(bool) Getter for squash_output.""" return self._squash_output @staticmethod @@ -101,16 +184,7 @@ class BasePolicy(nn.Module, ABC): nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) - @staticmethod - def _dummy_schedule(progress_remaining: float) -> float: - """ (float) Useful for pickling policy.""" - del progress_remaining - return 0.0 - @abstractmethod - def forward(self, *args, **kwargs): - del args, kwargs - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -122,7 +196,6 @@ class BasePolicy(nn.Module, ABC): :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ - raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -205,64 +278,6 @@ class BasePolicy(nn.Module, ABC): low, high = self.action_space.low, self.action_space.high return low + (0.5 * (scaled_action + 1.0) * (high - low)) - def _get_data(self) -> Dict[str, Any]: - """ - Get data that need to be saved in order to re-create the policy. - This corresponds to the arguments of the constructor. - - :return: (Dict[str, Any]) - """ - return dict( - observation_space=self.observation_space, - action_space=self.action_space, - # Passed to the constructor by child class - # squash_output=self.squash_output, - # features_extractor=self.features_extractor - normalize_images=self.normalize_images, - ) - - def save(self, path: str) -> None: - """ - Save policy to a given location. - - :param path: (str) - """ - th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) - - @classmethod - def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': - """ - Load policy from path. - - :param path: (str) - :param device: (Union[th.device, str]) Device on which the policy should be loaded. - :return: (BasePolicy) - """ - device = get_device(device) - saved_variables = th.load(path, map_location=device) - # Create policy object - model = cls(**saved_variables['data']) - # Load weights - model.load_state_dict(saved_variables['state_dict']) - model.to(device) - return model - - def load_from_vector(self, vector: np.ndarray): - """ - Load parameters from a 1D vector. - - :param vector: (np.ndarray) - """ - th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) - - def parameters_to_vector(self) -> np.ndarray: - """ - Convert the parameters to a 1D vector. - - :return: (np.ndarray) - """ - return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() - class ActorCriticPolicy(BasePolicy): """ diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 335bbd4..e408622 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -5,7 +5,7 @@ import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor +from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy, create_sde_features_extractor from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution @@ -179,7 +179,7 @@ class Actor(BasePolicy): return self.forward(observation, deterministic) -class Critic(BasePolicy): +class Critic(BaseModel): """ Critic network (q-value function) for SAC. diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 7b863ad..addadf5 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -5,7 +5,7 @@ import torch as th import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy +from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor @@ -71,7 +71,7 @@ class Actor(BasePolicy): return self.forward(observation, deterministic=deterministic) -class Critic(BasePolicy): +class Critic(BaseModel): """ Critic network for TD3, in fact it represents the action-state value function (Q-value function) From e7130344de6548ce843193f27c6e69fd99f21926 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 7 Jul 2020 19:03:46 -0700 Subject: [PATCH 11/13] Add changelog entry --- docs/misc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0ca6fa9..a65ec1d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -21,6 +21,7 @@ New Features: when ``psutil`` is available - Saving models now automatically creates the necessary folders and raises appropriate warnings (@PartiallyTyped) - Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped) +- Introduced ``BaseModel`` abstract parent for ``BasePolicy``, which critics inherit from. Bug Fixes: ^^^^^^^^^^ From fc0b0b8824bd10c46950b5d5c53b499fe80ebb20 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Tue, 7 Jul 2020 19:05:10 -0700 Subject: [PATCH 12/13] Lint --- stable_baselines3/common/policies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 164a247..bc387c2 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -163,7 +163,6 @@ class BasePolicy(BaseModel): super(BasePolicy, self).__init__(*args, **kwargs) self._squash_output = squash_output - @staticmethod def _dummy_schedule(progress_remaining: float) -> float: """ (float) Useful for pickling policy.""" From 3cf6e9714b816ab7f1352d6aa439059becff707b Mon Sep 17 00:00:00 2001 From: Joel Joseph <34275997+joeljosephjin@users.noreply.github.com> Date: Fri, 10 Jul 2020 14:08:35 +0530 Subject: [PATCH 13/13] Update ppo.rst (#94) * Update ppo.rst minor correction from A2C to PPO * Update changelog.rst * Update changelog.rst Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- docs/modules/ppo.rst | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a65ec1d..b261a0e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -44,6 +44,7 @@ Documentation: - Fixed a typo in the section of Enjoy a Trained Agent, in RL Baselines3 Zoo README. (@blurLake) - Added Unity reacher to the projects page (@koulakis) - Added PyBullet colab notebook +- Fixed typo in PPO example code (@joeljosephjin) @@ -346,4 +347,4 @@ And all the contributors: @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 -@tirafesi @blurLake @koulakis +@tirafesi @blurLake @koulakis @joeljosephjin diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index f03e92e..038149d 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -53,7 +53,7 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments. import gym - from stable_baselines3 import A2C + from stable_baselines3 import PPO from stable_baselines3.ppo import MlpPolicy from stable_baselines3.common.cmd_util import make_vec_env