mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-29 03:31:08 +00:00
Merge branch 'master' into sde
This commit is contained in:
commit
fc4bf016fd
34 changed files with 1813 additions and 1408 deletions
|
|
@ -155,9 +155,9 @@ All the following examples can be executed online using Google colab notebooks:
|
|||
- [All Notebooks](https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3)
|
||||
- [Getting Started](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb)
|
||||
- [Training, Saving, Loading](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb)
|
||||
<!-- - [Multiprocessing](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb)
|
||||
- [Multiprocessing](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb)
|
||||
- [Monitor Training and Plotting](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb)
|
||||
- [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb) -->
|
||||
- [Atari Games](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb)
|
||||
- [RL Baselines Zoo](https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ This will give you access to events (``_on_training_start``, ``_on_step``) and u
|
|||
# Those variables will be accessible in the callback
|
||||
# (they are defined in the base class)
|
||||
# The RL model
|
||||
# self.model = None # type: BaseRLModel
|
||||
# self.model = None # type: BaseAlgorithm
|
||||
# An alias for self.model.get_env(), the environment used for training
|
||||
# self.training_env = None # type: Union[gym.Env, VecEnv, None]
|
||||
# Number of times the callback was called
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ The library is not meant to be modular, although inheritance is used to reduce c
|
|||
Algorithms Structure
|
||||
====================
|
||||
|
||||
|
||||
Each algorithm (on-policy and off-policy ones) follows a common structure.
|
||||
Policy contains code for acting in the environment, and algorithm updates this policy.
|
||||
There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (``policies.py``).
|
||||
|
||||
Each algorithm has two main methods:
|
||||
|
|
@ -34,13 +36,14 @@ Where to start?
|
|||
|
||||
The first thing you need to read and understand are the base classes in the ``common/`` folder:
|
||||
|
||||
- ``BaseRLModel`` in ``base_class.py`` which defines how an RL class should look like.
|
||||
- ``BaseAlgorithm`` in ``base_class.py`` which defines what an RL class should look like.
|
||||
It also contains all the "glue code" for saving/loading and the common operations (wrapping environments)
|
||||
|
||||
- ``BasePolicy`` in ``policies.py`` which defines what a policy class should look like.
|
||||
It contains also all the magic for the ``.predict()`` method, to handle as many cases as possible
|
||||
It also contains all the magic for the ``.predict()`` method, to handle as many spaces/cases as possible
|
||||
|
||||
- ``OffPolicyRLModel`` in ``base_class.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms
|
||||
- ``OffPolicyAlgorithm`` in ``off_policy_algorithm.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms,
|
||||
and similarly ``OnPolicyAlgorithm`` in ``on_policy_algorithm.py``.
|
||||
|
||||
|
||||
All the environments handled internally are assumed to be ``VecEnv`` (``gym.Env`` are automatically wrapped).
|
||||
|
|
@ -50,7 +53,7 @@ Pre-Processing
|
|||
==============
|
||||
|
||||
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observations).
|
||||
Most of the code for pre-processing is in ``common/preprocessing.py``.
|
||||
Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.
|
||||
|
||||
For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
|
||||
|
||||
|
|
@ -61,9 +64,12 @@ Policy Structure
|
|||
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
|
||||
In SB3, "policy" refers to the class that handles all the networks useful for training,
|
||||
so not only the network used to predict actions (the "learned controller").
|
||||
|
||||
For instance, the ``TD3`` policy contains the actor, the critic and the target networks.
|
||||
|
||||
To avoid the hassle of importing specific policy classes for a specific algorithm (e.g. both A2C and PPO use ``ActorCriticPolicy``),
|
||||
SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks,
|
||||
respectively. Importing ``[algorithm]/policies.py`` registers an appropriate policy for that algorithm under those names.
|
||||
|
||||
Probability distributions
|
||||
=========================
|
||||
|
||||
|
|
|
|||
|
|
@ -13,13 +13,11 @@ notebooks:
|
|||
- `All Notebooks <https://github.com/Stable-Baselines-Team/rl-colab-notebooks/tree/sb3>`_
|
||||
- `Getting Started`_
|
||||
- `Training, Saving, Loading`_
|
||||
- `Multiprocessing`_
|
||||
- `Monitor Training and Plotting`_
|
||||
- `Atari Games`_
|
||||
- `RL Baselines zoo`_
|
||||
|
||||
|
||||
.. - `Multiprocessing`_
|
||||
.. - `Monitor Training and Plotting`_
|
||||
.. - `Atari Games`_
|
||||
.. - `Breakout`_ (trained agent included)
|
||||
.. - `Hindsight Experience Replay`_
|
||||
|
||||
.. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
|
||||
|
|
@ -27,7 +25,6 @@ notebooks:
|
|||
.. _Multiprocessing: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
|
||||
.. _Monitor Training and Plotting: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
|
||||
.. _Atari Games: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
|
||||
.. _Breakout: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/breakout.ipynb
|
||||
.. _Hindsight Experience Replay: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb
|
||||
.. _RL Baselines zoo: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/rl-baselines-zoo.ipynb
|
||||
|
||||
|
|
@ -91,9 +88,9 @@ In the following example, we will train, save and load a A2C model on the Lunar
|
|||
|
||||
Multiprocessing: Unleashing the Power of Vectorized Environments
|
||||
----------------------------------------------------------------
|
||||
..
|
||||
.. .. image:: ../_static/img/colab-badge.svg
|
||||
.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
|
||||
|
||||
.. image:: ../_static/img/colab-badge.svg
|
||||
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb
|
||||
|
||||
.. figure:: https://cdn-images-1.medium.com/max/960/1*h4WTQNVIsvMXJTCpXm_TAw.gif
|
||||
|
||||
|
|
@ -106,7 +103,6 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.ppo import MlpPolicy
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.cmd_util import make_vec_env
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
|
@ -160,12 +156,9 @@ This could be useful when you want to monitor training, for instance display liv
|
|||
learning curves in Tensorboard (or in Visdom) or save the best agent.
|
||||
If your callback returns False, training is aborted early.
|
||||
|
||||
.. .. image:: ../_static/img/colab-badge.svg
|
||||
.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
|
||||
..
|
||||
.. .. figure:: ../_static/img/learning_curve.png
|
||||
..
|
||||
.. Learning curve of TD3 on LunarLanderContinuous environment
|
||||
.. image:: ../_static/img/colab-badge.svg
|
||||
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/monitor_training.ipynb
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
@ -176,7 +169,6 @@ If your callback returns False, training is aborted early.
|
|||
import matplotlib.pyplot as plt
|
||||
|
||||
from stable_baselines3 import TD3
|
||||
from stable_baselines3.td3 import MlpPolicy
|
||||
from stable_baselines3.common import results_plotter
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
|
||||
|
|
@ -240,7 +232,7 @@ If your callback returns False, training is aborted early.
|
|||
n_actions = env.action_space.shape[-1]
|
||||
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
|
||||
# Because we use parameter noise, we should use a MlpPolicy with layer normalization
|
||||
model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=0)
|
||||
model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
|
||||
# Create the callback: check every 1000 steps
|
||||
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
|
||||
# Train the agent
|
||||
|
|
@ -267,8 +259,8 @@ Training a RL agent on Atari games is straightforward thanks to ``make_atari_env
|
|||
It will do `all the preprocessing <https://danieltakeshi.github.io/2016/11/25/frame-skipping-and-preprocessing-for-deep-q-networks-on-atari-2600-games/>`_
|
||||
and multiprocessing for you.
|
||||
|
||||
.. .. image:: ../_static/img/colab-badge.svg
|
||||
.. :target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
|
||||
.. image:: ../_static/img/colab-badge.svg
|
||||
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/atari_games.ipynb
|
||||
..
|
||||
|
||||
.. code-block:: python
|
||||
|
|
|
|||
|
|
@ -3,13 +3,55 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
|
||||
Pre-Release 0.7.0a1 (WIP)
|
||||
Pre-Release 0.8.0a0 (WIP)
|
||||
------------------------------
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
Others:
|
||||
^^^^^^^
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Updated notebook links
|
||||
|
||||
|
||||
Pre-Release 0.7.0 (2020-06-10)
|
||||
------------------------------
|
||||
|
||||
**Hotfix for PPO/A2C + gSDE, internal refactoring and bug fixes**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- ``render()`` method of ``VecEnvs`` now only accepts one argument: ``mode``
|
||||
- Created new file common/torch_layers.py, similar to SB refactoring
|
||||
|
||||
- Contains all PyTorch network layer definitions and feature extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``
|
||||
|
||||
- Renamed ``BaseRLModel`` to ``BaseAlgorithm`` (along with offpolicy and onpolicy variants)
|
||||
- Moved on-policy and off-policy base algorithms to ``common/on_policy_algorithm.py`` and ``common/off_policy_algorithm.py``, respectively.
|
||||
- Moved ``PPOPolicy`` to ``ActorCriticPolicy`` in common/policies.py
|
||||
- Moved ``PPO`` (algorithm class) into ``OnPolicyAlgorithm`` (``common/on_policy_algorithm.py``), to be shared with A2C
|
||||
- Moved the following functions from ``BaseAlgorithm``:
|
||||
|
||||
- ``_load_from_file`` to ``load_from_zip_file`` (save_util.py)
|
||||
- ``_save_to_file_zip`` to ``save_to_zip_file`` (save_util.py)
|
||||
- ``safe_mean`` to ``safe_mean`` (utils.py)
|
||||
- ``check_env`` to ``check_for_correct_spaces`` (utils.py. Renamed to avoid confusion with environment checker tools)
|
||||
|
||||
- Moved static function ``_is_vectorized_observation`` from common/policies.py to common/utils.py under name ``is_vectorized_observation``.
|
||||
- Removed ``{save,load}_running_average`` functions of ``VecNormalize`` in favor of ``load/save``.
|
||||
- Removed ``use_gae`` parameter from ``RolloutBuffer.compute_returns_and_advantage``.
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -17,8 +59,10 @@ New Features:
|
|||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Fixed ``render()`` method for ``VecEnvs``
|
||||
- Fixed ``seed()``` method for ``SubprocVecEnv``
|
||||
- Fixed ``seed()`` method for ``SubprocVecEnv``
|
||||
- Fixed loading on GPU for testing when using gSDE and ``deterministic=False``
|
||||
- Fixed ``register_policy`` to allow re-registering same policy for same sub-class (i.e. assign same value to same key).
|
||||
- Fixed a bug where the gradient was passed when using ``gSDE`` with ``PPO``/``A2C``, this does not affect ``SAC``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -26,10 +70,18 @@ Deprecations:
|
|||
Others:
|
||||
^^^^^^^
|
||||
- Re-enable unsafe ``fork`` start method in the tests (was causing a deadlock with tensorflow)
|
||||
- Added a test for seeding ``SubprocVecEnv``` and rendering
|
||||
- Added a test for seeding ``SubprocVecEnv`` and rendering
|
||||
- Fixed reference in NatureCNN (pointed to older version with different network architecture)
|
||||
- Fixed comments saying "CxWxH" instead of "CxHxW" (same style as in torch docs / commonly used)
|
||||
- Added a few more comments on registering/getting policies ("MlpPolicy", "CnnPolicy").
|
||||
- Renamed ``progress`` (value from 1 in start of training to 0 in end) to ``progress_remaining``.
|
||||
- Added ``policies.py`` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
|
||||
- Added some missing tests for ``VecNormalize``, ``VecCheckNan`` and ``PPO``.
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added a paragraph on "MlpPolicy"/"CnnPolicy" and policy naming scheme under "Developer Guide"
|
||||
- Fixed second-level listing in changelog
|
||||
|
||||
|
||||
Pre-Release 0.6.0 (2020-06-01)
|
||||
|
|
@ -40,6 +92,7 @@ Pre-Release 0.6.0 (2020-06-01)
|
|||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Methods were renamed in the logger:
|
||||
|
||||
- ``logkv`` -> ``record``, ``writekvs`` -> ``write``, ``writeseq`` -> ``write_sequence``,
|
||||
- ``logkvs`` -> ``record_dict``, ``dumpkvs`` -> ``dump``,
|
||||
- ``getkvs`` -> ``get_log_dict``, ``logkv_mean`` -> ``record_mean``
|
||||
|
|
|
|||
|
|
@ -8,14 +8,29 @@ Base RL Class
|
|||
|
||||
Common interface for all the RL algorithms
|
||||
|
||||
.. autoclass:: BaseRLModel
|
||||
.. autoclass:: BaseAlgorithm
|
||||
:members:
|
||||
|
||||
|
||||
.. automodule:: stable_baselines3.common.off_policy_algorithm
|
||||
|
||||
|
||||
Base Off-Policy Class
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The base RL model for Off-Policy algorithm (ex: SAC/TD3)
|
||||
The base RL algorithm for Off-Policy algorithms (e.g. SAC/TD3)
|
||||
|
||||
.. autoclass:: OffPolicyRLModel
|
||||
.. autoclass:: OffPolicyAlgorithm
|
||||
:members:
|
||||
|
||||
|
||||
.. automodule:: stable_baselines3.common.on_policy_algorithm
|
||||
|
||||
|
||||
Base On-Policy Class
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The base RL algorithm for On-Policy algorithms (e.g. A2C/PPO)
|
||||
|
||||
.. autoclass:: OnPolicyAlgorithm
|
||||
:members:
|
||||
|
|
|
|||
|
|
@ -12,12 +12,6 @@ SAC is the successor of `Soft Q-Learning SQL <https://arxiv.org/abs/1702.08165>`
|
|||
A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
The SAC model does not support ``stable_baselines3.ppo.policies`` because it uses double q-values
|
||||
and value estimation, as a result it must use its own policy models (see :ref:`sac_policies`).
|
||||
|
||||
|
||||
.. rubric:: Available Policies
|
||||
|
||||
.. autosummary::
|
||||
|
|
|
|||
|
|
@ -12,12 +12,6 @@ TD3 is a direct successor of DDPG and improves it using three major tricks: clip
|
|||
We recommend reading `OpenAI Spinning Up guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
The TD3 model does not support ``stable_baselines3.ppo.policies`` because it uses double q-values
|
||||
estimation, as a result it must use its own policy models (see :ref:`td3_policies`).
|
||||
|
||||
|
||||
.. rubric:: Available Policies
|
||||
|
||||
.. autosummary::
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
from stable_baselines3.a2c.a2c import A2C
|
||||
from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy
|
||||
from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from gym import spaces
|
|||
from typing import Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.utils import explained_variance
|
||||
from stable_baselines3.ppo.policies import PPOPolicy
|
||||
from stable_baselines3.ppo.ppo import PPO
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
|
||||
|
||||
class A2C(PPO):
|
||||
class A2C(OnPolicyAlgorithm):
|
||||
"""
|
||||
Advantage Actor Critic (A2C)
|
||||
|
||||
|
|
@ -20,7 +20,7 @@ class A2C(PPO):
|
|||
|
||||
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
|
||||
|
||||
:param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: (float or callable) The learning rate, it can be a function
|
||||
:param n_steps: (int) The number of steps to run for each environment per update
|
||||
|
|
@ -49,7 +49,8 @@ class A2C(PPO):
|
|||
Setting it to auto, the code will be run on the GPU if possible.
|
||||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
def __init__(self, policy: Union[str, Type[PPOPolicy]],
|
||||
|
||||
def __init__(self, policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Callable] = 7e-4,
|
||||
n_steps: int = 5,
|
||||
|
|
@ -72,16 +73,17 @@ class A2C(PPO):
|
|||
_init_setup_model: bool = True):
|
||||
|
||||
super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
|
||||
n_steps=n_steps, batch_size=None, n_epochs=1,
|
||||
gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
|
||||
vf_coef=vf_coef, max_grad_norm=max_grad_norm,
|
||||
n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda,
|
||||
ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
|
||||
verbose=verbose, device=device, create_eval_env=create_eval_env,
|
||||
seed=seed, _init_setup_model=False)
|
||||
|
||||
self.normalize_advantage = normalize_advantage
|
||||
# Override PPO optimizer to match original implementation
|
||||
|
||||
# Update optimizer inside the policy if we want to use RMSProp
|
||||
# (original implementation) rather than Adam
|
||||
if use_rms_prop and 'optimizer_class' not in self.policy_kwargs:
|
||||
self.policy_kwargs['optimizer_class'] = th.optim.RMSprop
|
||||
self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
|
||||
|
|
@ -90,13 +92,13 @@ class A2C(PPO):
|
|||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
|
||||
def train(self) -> None:
|
||||
"""
|
||||
Update policy using the currently gathered
|
||||
rollout buffer (one gradient step over whole data).
|
||||
"""
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# A2C with gradient_steps > 1 does not make sense
|
||||
assert gradient_steps == 1, "A2C does not support multiple gradient steps"
|
||||
# We do not use minibatches for A2C
|
||||
assert batch_size is None, "A2C does not support minibatch"
|
||||
|
||||
for rollout_data in self.rollout_buffer.get(batch_size=None):
|
||||
|
||||
|
|
@ -160,7 +162,7 @@ class A2C(PPO):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> 'A2C':
|
||||
|
||||
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
|
||||
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
|
||||
tb_log_name=tb_log_name, eval_log_path=eval_log_path,
|
||||
reset_num_timesteps=reset_num_timesteps)
|
||||
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback,
|
||||
log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps)
|
||||
|
|
|
|||
9
stable_baselines3/a2c/policies.py
Normal file
9
stable_baselines3/a2c/policies.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
# This file is here just to define MlpPolicy/CnnPolicy
|
||||
# that work for A2C
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy
|
||||
|
||||
MlpPolicy = ActorCriticPolicy
|
||||
CnnPolicy = ActorCriticCnnPolicy
|
||||
|
||||
register_policy("MlpPolicy", ActorCriticPolicy)
|
||||
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
||||
|
|
@ -1,8 +1,4 @@
|
|||
import time
|
||||
import os
|
||||
import io
|
||||
import zipfile
|
||||
import pickle
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
|
@ -13,27 +9,28 @@ import numpy as np
|
|||
|
||||
from stable_baselines3.common import logger, utils
|
||||
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
||||
from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
|
||||
from stable_baselines3.common.utils import (set_random_seed, get_schedule_fn, update_learning_rate, get_device,
|
||||
check_for_correct_spaces)
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
from stable_baselines3.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
|
||||
from stable_baselines3.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
|
||||
from stable_baselines3.common.save_util import (recursive_getattr, recursive_setattr, save_to_zip_file,
|
||||
load_from_zip_file)
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
|
||||
|
||||
class BaseRLModel(ABC):
|
||||
class BaseAlgorithm(ABC):
|
||||
"""
|
||||
The base RL model
|
||||
The base of RL algorithms
|
||||
|
||||
:param policy: (Type[BasePolicy]) Policy object
|
||||
:param env: (Union[GymEnv, str]) The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: (Type[BasePolicy]) The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
|
||||
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
|
||||
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
|
||||
|
|
@ -102,9 +99,9 @@ class BaseRLModel(ABC):
|
|||
# Used for gSDE only
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
# Track the training progress (from 1 to 0)
|
||||
# Track the training progress remaining (from 1 to 0)
|
||||
# this is used to update the learning rate
|
||||
self._current_progress = 1
|
||||
self._current_progress_remaining = 1
|
||||
# Buffers for logging
|
||||
self.ep_info_buffer = None # type: Optional[deque]
|
||||
self.ep_success_buffer = None # type: Optional[deque]
|
||||
|
|
@ -176,41 +173,30 @@ class BaseRLModel(ABC):
|
|||
"""Transform to callable if needed."""
|
||||
self.lr_schedule = get_schedule_fn(self.learning_rate)
|
||||
|
||||
def _update_current_progress(self, num_timesteps: int, total_timesteps: int) -> None:
|
||||
def _update_current_progress_remaining(self, num_timesteps: int, total_timesteps: int) -> None:
|
||||
"""
|
||||
Compute current progress (from 1 to 0)
|
||||
Compute current progress remaining (starts from 1 and ends to 0)
|
||||
|
||||
:param num_timesteps: current number of timesteps
|
||||
:param total_timesteps:
|
||||
"""
|
||||
self._current_progress = 1.0 - float(num_timesteps) / float(total_timesteps)
|
||||
self._current_progress_remaining = 1.0 - float(num_timesteps) / float(total_timesteps)
|
||||
|
||||
def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
|
||||
"""
|
||||
Update the optimizers learning rate using the current learning rate schedule
|
||||
and the current progress (from 1 to 0).
|
||||
and the current progress remaining (from 1 to 0).
|
||||
|
||||
:param optimizers: (Union[List[th.optim.Optimizer], th.optim.Optimizer])
|
||||
An optimizer or a list of optimizers.
|
||||
"""
|
||||
# Log the current learning rate
|
||||
logger.record("train/learning_rate", self.lr_schedule(self._current_progress))
|
||||
logger.record("train/learning_rate", self.lr_schedule(self._current_progress_remaining))
|
||||
|
||||
if not isinstance(optimizers, list):
|
||||
optimizers = [optimizers]
|
||||
for optimizer in optimizers:
|
||||
update_learning_rate(optimizer, self.lr_schedule(self._current_progress))
|
||||
|
||||
@staticmethod
|
||||
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
|
||||
"""
|
||||
Compute the mean of an array if there is at least one element.
|
||||
For empty array, return NaN. It is used for logging only.
|
||||
|
||||
:param arr:
|
||||
:return:
|
||||
"""
|
||||
return np.nan if len(arr) == 0 else np.mean(arr)
|
||||
update_learning_rate(optimizer, self.lr_schedule(self._current_progress_remaining))
|
||||
|
||||
def get_env(self) -> Optional[VecEnv]:
|
||||
"""
|
||||
|
|
@ -228,26 +214,6 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
return self._vec_normalize_env
|
||||
|
||||
@staticmethod
|
||||
def check_env(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
|
||||
"""
|
||||
Checks the validity of the environment to load vs the one used for training.
|
||||
Checked parameters:
|
||||
- observation_space
|
||||
- action_space
|
||||
|
||||
:param env: (GymEnv)
|
||||
:param observation_space: (gym.spaces.Space)
|
||||
:param action_space: (gym.spaces.Space)
|
||||
"""
|
||||
if (observation_space != env.observation_space
|
||||
# Special cases for images that need to be transposed
|
||||
and not (is_image_space(env.observation_space)
|
||||
and observation_space == VecTransposeImage.transpose_space(env.observation_space))):
|
||||
raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}')
|
||||
if action_space != env.action_space:
|
||||
raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}')
|
||||
|
||||
def set_env(self, env: GymEnv) -> None:
|
||||
"""
|
||||
Checks the validity of the environment, and if it is coherent, set it as the current environment.
|
||||
|
|
@ -258,7 +224,7 @@ class BaseRLModel(ABC):
|
|||
|
||||
:param env: The environment for learning a policy
|
||||
"""
|
||||
self.check_env(env, self.observation_space, self.action_space)
|
||||
check_for_correct_spaces(env, self.observation_space, self.action_space)
|
||||
# it must be coherent now
|
||||
# if it is not a VecEnv, make it a VecEnv
|
||||
env = self._wrap_env(env)
|
||||
|
|
@ -288,7 +254,7 @@ class BaseRLModel(ABC):
|
|||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> 'BaseRLModel':
|
||||
reset_num_timesteps: bool = True) -> 'BaseAlgorithm':
|
||||
"""
|
||||
Return a trained model.
|
||||
|
||||
|
|
@ -297,13 +263,12 @@ class BaseRLModel(ABC):
|
|||
It takes the local and global variables. If it returns False, training is aborted.
|
||||
:param log_interval: (int) The number of timesteps before logging.
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
|
||||
:param eval_env: (gym.Env) Environment that will be used to evaluate the agent
|
||||
:param eval_freq: (int) Evaluate the agent every ``eval_freq`` timesteps (this may vary a little)
|
||||
:param n_eval_episodes: (int) Number of episode to evaluate the agent
|
||||
:param eval_log_path: (Optional[str]) Path to a folder where the evaluations will be saved
|
||||
:param reset_num_timesteps: (bool)
|
||||
:return: (BaseRLModel) the trained model
|
||||
:param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging)
|
||||
:return: (BaseAlgorithm) the trained model
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
@ -333,7 +298,7 @@ class BaseRLModel(ABC):
|
|||
(can be None if you only need prediction from a trained model) has priority over any saved environment
|
||||
:param kwargs: extra arguments to change the model when loading
|
||||
"""
|
||||
data, params, tensors = cls._load_from_file(load_path)
|
||||
data, params, tensors = load_from_zip_file(load_path)
|
||||
|
||||
if 'policy_kwargs' in data:
|
||||
for arg_to_remove in ['device']:
|
||||
|
|
@ -349,7 +314,7 @@ class BaseRLModel(ABC):
|
|||
raise ValueError("The observation_space and action_space was not given, can't verify new environments")
|
||||
# check if given env is valid
|
||||
if env is not None:
|
||||
cls.check_env(env, data["observation_space"], data["action_space"])
|
||||
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
|
||||
# if no new env was given use stored env if possible
|
||||
if env is None and "env" in data:
|
||||
env = data["env"]
|
||||
|
|
@ -380,79 +345,6 @@ class BaseRLModel(ABC):
|
|||
model.policy.reset_noise()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
|
||||
Optional[TensorDict],
|
||||
Optional[TensorDict]]):
|
||||
""" Load model data from a .zip archive
|
||||
|
||||
:param load_path: Where to load the model from
|
||||
:param load_data: Whether we should load and return data
|
||||
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
|
||||
:return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
|
||||
and dict of extra tensors
|
||||
"""
|
||||
# Check if file exists if load_path is a string
|
||||
if isinstance(load_path, str):
|
||||
if not os.path.exists(load_path):
|
||||
if os.path.exists(load_path + ".zip"):
|
||||
load_path += ".zip"
|
||||
else:
|
||||
raise ValueError(f"Error: the file {load_path} could not be found")
|
||||
|
||||
# set device to cpu if cuda is not available
|
||||
device = get_device()
|
||||
|
||||
# Open the zip archive and load data
|
||||
try:
|
||||
with zipfile.ZipFile(load_path, "r") as archive:
|
||||
namelist = archive.namelist()
|
||||
# If data or parameters is not in the
|
||||
# zip archive, assume they were stored
|
||||
# as None (_save_to_file_zip allows this).
|
||||
data = None
|
||||
tensors = None
|
||||
params = {}
|
||||
|
||||
if "data" in namelist and load_data:
|
||||
# Load class parameters and convert to string
|
||||
json_data = archive.read("data").decode()
|
||||
data = json_to_data(json_data)
|
||||
|
||||
if "tensors.pth" in namelist and load_data:
|
||||
# Load extra tensors
|
||||
with archive.open('tensors.pth', mode="r") as tensor_file:
|
||||
# File has to be seekable, but tensor_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
file_content.write(tensor_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# load the parameters with the right ``map_location``
|
||||
tensors = th.load(file_content, map_location=device)
|
||||
|
||||
# check for all other .pth files
|
||||
other_files = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"]
|
||||
# if there are any other files which end with .pth and aren't "tensors.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_files) > 0:
|
||||
for file_path in other_files:
|
||||
with archive.open(file_path, mode="r") as opt_param_file:
|
||||
# File has to be seekable, but opt_param_file is not, so load in BytesIO first
|
||||
# fixed in python >= 3.7
|
||||
file_content = io.BytesIO()
|
||||
file_content.write(opt_param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# load the parameters with the right ``map_location``
|
||||
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
# load_path wasn't a zip file
|
||||
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
|
||||
return data, params, tensors
|
||||
|
||||
def set_random_seed(self, seed: Optional[int] = None) -> None:
|
||||
"""
|
||||
Set the seed of the pseudo-random generators
|
||||
|
|
@ -513,8 +405,8 @@ class BaseRLModel(ABC):
|
|||
:param total_timesteps: (int) The total number of samples (env steps) to train on
|
||||
:param eval_env: (Optional[GymEnv])
|
||||
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
|
||||
:param eval_freq: (int)
|
||||
:param n_eval_episodes: (int)
|
||||
:param eval_freq: (int) How many steps between evaluations
|
||||
:param n_eval_episodes: (int) How many episodes to play per evaluation
|
||||
:param log_path: (Optional[str]) Path to a log folder
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
|
|
@ -571,45 +463,6 @@ class BaseRLModel(ABC):
|
|||
if maybe_is_success is not None and dones[idx]:
|
||||
self.ep_success_buffer.append(maybe_is_success)
|
||||
|
||||
@staticmethod
|
||||
def _save_to_file_zip(save_path: str, data: Dict[str, Any] = None,
|
||||
params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
|
||||
"""
|
||||
Save model to a zip archive.
|
||||
|
||||
:param save_path: Where to store the model
|
||||
:param data: Class parameters being stored
|
||||
:param params: Model parameters being stored expected to contain an entry for every
|
||||
state_dict with its name and the state_dict
|
||||
:param tensors: Extra tensor variables expected to contain name and value of tensors
|
||||
"""
|
||||
|
||||
# data/params can be None, so do not
|
||||
# try to serialize them blindly
|
||||
if data is not None:
|
||||
serialized_data = data_to_json(data)
|
||||
|
||||
# Check postfix if save_path is a string
|
||||
if isinstance(save_path, str):
|
||||
_, ext = os.path.splitext(save_path)
|
||||
if ext == "":
|
||||
save_path += ".zip"
|
||||
|
||||
# Create a zip-archive and write our objects
|
||||
# there. This works when save_path is either
|
||||
# str or a file-like
|
||||
with zipfile.ZipFile(save_path, "w") as archive:
|
||||
# Do not try to save "None" elements
|
||||
if data is not None:
|
||||
archive.writestr("data", serialized_data)
|
||||
if tensors is not None:
|
||||
with archive.open('tensors.pth', mode="w") as tensors_file:
|
||||
th.save(tensors, tensors_file)
|
||||
if params is not None:
|
||||
for file_name, dict_ in params.items():
|
||||
with archive.open(file_name + '.pth', mode="w") as param_file:
|
||||
th.save(dict_, param_file)
|
||||
|
||||
def excluded_save_params(self) -> List[str]:
|
||||
"""
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
|
|
@ -668,279 +521,4 @@ class BaseRLModel(ABC):
|
|||
# Retrieve state dict
|
||||
params_to_save[name] = attr.state_dict()
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
|
||||
class OffPolicyRLModel(BaseRLModel):
|
||||
"""
|
||||
The base RL model for Off-Policy algorithm (ex: SAC/TD3)
|
||||
|
||||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
if it is not possible.
|
||||
:param support_multi_env: Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
create_eval_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False):
|
||||
|
||||
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate,
|
||||
policy_kwargs, tensorboard_log, verbose,
|
||||
device, support_multi_env, create_eval_env, monitor_wrapper,
|
||||
seed, use_sde, sde_sample_freq)
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.learning_starts = learning_starts
|
||||
self.actor = None # type: Optional[th.nn.Module]
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
# Update policy keyword arguments
|
||||
self.policy_kwargs['use_sde'] = self.use_sde
|
||||
self.policy_kwargs['device'] = self.device
|
||||
# For SDE only
|
||||
self.rollout_data = None
|
||||
self.on_policy_exploration = False
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _setup_model(self):
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
|
||||
self.action_space, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: str):
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: (str) Path to a log folder
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
|
||||
pickle.dump(self.replay_buffer, file_handler)
|
||||
|
||||
def load_replay_buffer(self, path: str):
|
||||
"""
|
||||
|
||||
:param path: (str) Path to the pickled replay buffer.
|
||||
"""
|
||||
with open(path, 'rb') as file_handler:
|
||||
self.replay_buffer = pickle.load(file_handler)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
|
||||
|
||||
def collect_rollouts(self, # noqa: C901
|
||||
env: VecEnv,
|
||||
# Type hint as string to avoid circular import
|
||||
callback: 'BaseCallback',
|
||||
n_episodes: int = 1,
|
||||
n_steps: int = -1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
learning_starts: int = 0,
|
||||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
log_interval: Optional[int] = None) -> RolloutReturn:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
|
||||
:param env: (VecEnv) The training environment
|
||||
:param n_episodes: (int) Number of episodes to use to collect rollout data
|
||||
You can also specify a ``n_steps`` instead
|
||||
:param n_steps: (int) Number of steps to use to collect rollout data
|
||||
You can also specify a ``n_episodes`` instead.
|
||||
:param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
:param callback: (BaseCallback) Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param log_interval: (int) Log data every ``log_interval`` episodes
|
||||
:return: (RolloutReturn)
|
||||
"""
|
||||
episode_rewards, total_timesteps = [], []
|
||||
total_steps, total_episodes = 0, 0
|
||||
|
||||
assert isinstance(env, VecEnv), "You must pass a VecEnv"
|
||||
assert env.num_envs == 1, "OffPolicyRLModel only support single environment"
|
||||
|
||||
self.rollout_data = None
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
if self.on_policy_exploration:
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
episode_reward, episode_timesteps = 0.0, 0
|
||||
|
||||
while not done:
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
||||
# Select action randomly or according to policy
|
||||
if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
|
||||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
# Note: we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.policy.scale_action(unscaled_action)
|
||||
|
||||
if self.use_sde:
|
||||
# When using SDE, the action can be out of bounds
|
||||
# TODO: fix with squashing and account for that in the proba distribution
|
||||
clipped_action = np.clip(scaled_action, -1, 1)
|
||||
else:
|
||||
clipped_action = scaled_action
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
if action_noise is not None:
|
||||
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
|
||||
# Update(October 2019): Not anymore
|
||||
clipped_action = np.clip(clipped_action + action_noise(), -1, 1)
|
||||
|
||||
# Rescale and perform action
|
||||
new_obs, reward, done, infos = env.step(self.policy.unscale_action(clipped_action))
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, continue_training=False)
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos, done)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(self._last_original_obs, new_obs_, clipped_action, reward_, done)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(self._last_obs[0].copy())
|
||||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(done[0].copy())
|
||||
obs_tensor = th.FloatTensor(self._last_obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
self._last_obs = new_obs
|
||||
# Save the unnormalized observation
|
||||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = new_obs_
|
||||
|
||||
self.num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
total_steps += 1
|
||||
if 0 < n_steps <= total_steps:
|
||||
break
|
||||
|
||||
if done:
|
||||
total_episodes += 1
|
||||
self._episode_num += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Log training infos
|
||||
if log_interval is not None and self._episode_num % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.record('rollout/ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record('rollout/ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("time/fps", fps)
|
||||
logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
|
||||
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
if self.use_sde:
|
||||
logger.record("train/std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
logger.record('rollout/success rate', self.safe_mean(self.ep_success_buffer))
|
||||
# Pass the number of timesteps for tensorboard
|
||||
logger.dump(step=self.num_timesteps)
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone()
|
||||
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone()
|
||||
|
||||
# Compute return and advantage
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(self._last_obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return RolloutReturn(mean_reward, total_steps, total_episodes, continue_training)
|
||||
save_to_zip_file(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
|
|
|||
|
|
@ -240,49 +240,36 @@ class RolloutBuffer(BaseBuffer):
|
|||
|
||||
def compute_returns_and_advantage(self,
|
||||
last_value: th.Tensor,
|
||||
dones: np.ndarray,
|
||||
use_gae: bool = True) -> None:
|
||||
dones: np.ndarray) -> None:
|
||||
"""
|
||||
Post-processing step: compute the returns (sum of discounted rewards)
|
||||
and advantage (A(s) = R - V(S)).
|
||||
and GAE advantage.
|
||||
Adapted from Stable-Baselines PPO2.
|
||||
|
||||
Uses Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438)
|
||||
to compute the advantage. To obtain vanilla advantage (A(s) = R - V(S))
|
||||
where R is the discounted reward with value bootstrap,
|
||||
set ``gae_lambda=1.0`` during initialization.
|
||||
|
||||
:param last_value: (th.Tensor)
|
||||
:param dones: (np.ndarray)
|
||||
:param use_gae: (bool) Whether to use Generalized Advantage Estimation
|
||||
or normal advantage for advantage computation.
|
||||
|
||||
"""
|
||||
# convert to numpy
|
||||
last_value = last_value.clone().cpu().numpy().flatten()
|
||||
|
||||
if use_gae:
|
||||
last_gae_lam = 0
|
||||
for step in reversed(range(self.buffer_size)):
|
||||
if step == self.buffer_size - 1:
|
||||
next_non_terminal = 1.0 - dones
|
||||
next_value = last_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.dones[step + 1]
|
||||
next_value = self.values[step + 1]
|
||||
delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
|
||||
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
|
||||
self.advantages[step] = last_gae_lam
|
||||
self.returns = self.advantages + self.values
|
||||
else:
|
||||
# Discounted return with value bootstrap
|
||||
# Note: this is equivalent to GAE computation
|
||||
# with gae_lambda = 1.0
|
||||
last_return = 0.0
|
||||
for step in reversed(range(self.buffer_size)):
|
||||
if step == self.buffer_size - 1:
|
||||
next_non_terminal = 1.0 - dones
|
||||
next_value = last_value
|
||||
last_return = self.rewards[step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.dones[step + 1]
|
||||
last_return = self.rewards[step] + self.gamma * last_return * next_non_terminal
|
||||
self.returns[step] = last_return
|
||||
self.advantages = self.returns - self.values
|
||||
last_gae_lam = 0
|
||||
for step in reversed(range(self.buffer_size)):
|
||||
if step == self.buffer_size - 1:
|
||||
next_non_terminal = 1.0 - dones
|
||||
next_value = last_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.dones[step + 1]
|
||||
next_value = self.values[step + 1]
|
||||
delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
|
||||
last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
|
||||
self.advantages[step] = last_gae_lam
|
||||
self.returns = self.advantages + self.values
|
||||
|
||||
def add(self,
|
||||
obs: np.ndarray,
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from stable_baselines3.common.evaluation import evaluate_policy
|
|||
from stable_baselines3.common import logger
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.base_class import BaseRLModel # pytype: disable=pyi-error
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm # pytype: disable=pyi-error
|
||||
|
||||
|
||||
class BaseCallback(ABC):
|
||||
|
|
@ -24,7 +24,7 @@ class BaseCallback(ABC):
|
|||
def __init__(self, verbose: int = 0):
|
||||
super(BaseCallback, self).__init__()
|
||||
# The RL model
|
||||
self.model = None # type: Optional[BaseRLModel]
|
||||
self.model = None # type: Optional[BaseAlgorithm]
|
||||
# An alias for self.model.get_env(), the environment used for training
|
||||
self.training_env = None # type: Union[gym.Env, VecEnv, None]
|
||||
# Number of times the callback was called
|
||||
|
|
@ -40,7 +40,7 @@ class BaseCallback(ABC):
|
|||
self.parent = None # type: Optional[BaseCallback]
|
||||
|
||||
# Type hint as string to avoid circular import
|
||||
def init_callback(self, model: 'BaseRLModel') -> None:
|
||||
def init_callback(self, model: 'BaseAlgorithm') -> None:
|
||||
"""
|
||||
Initialize the callback by saving references to the
|
||||
RL model and the training environment for convenience.
|
||||
|
|
@ -118,7 +118,7 @@ class EventCallback(BaseCallback):
|
|||
if callback is not None:
|
||||
self.callback.parent = self
|
||||
|
||||
def init_callback(self, model: 'BaseRLModel') -> None:
|
||||
def init_callback(self, model: 'BaseAlgorithm') -> None:
|
||||
super(EventCallback, self).init_callback(model)
|
||||
if self.callback is not None:
|
||||
self.callback.init_callback(self.model)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
|
|||
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
|
||||
This is made to work only with one env.
|
||||
|
||||
:param model: (BaseRLModel) The RL agent you want to evaluate.
|
||||
:param model: (BaseAlgorithm) The RL agent you want to evaluate.
|
||||
:param env: (gym.Env or VecEnv) The gym environment. In the case of a ``VecEnv``
|
||||
this must contain only one environment.
|
||||
:param n_eval_episodes: (int) Number of episodes to evaluate the agent
|
||||
|
|
|
|||
279
stable_baselines3/common/off_policy_algorithm.py
Normal file
279
stable_baselines3/common/off_policy_algorithm.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
import time
|
||||
import os
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import Union, Type, Optional, Dict, Any, Callable
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.policies import BasePolicy
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
|
||||
|
||||
class OffPolicyAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
The base for Off-Policy algorithms (ex: SAC/TD3)
|
||||
|
||||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
|
||||
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
if it is not possible.
|
||||
:param support_multi_env: Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param sde_support: (bool) Whether the model supports gSDE or not
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
policy: Type[BasePolicy],
|
||||
env: Union[GymEnv, str],
|
||||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
buffer_size: int = int(1e6),
|
||||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
create_eval_env: bool = False,
|
||||
monitor_wrapper: bool = True,
|
||||
seed: Optional[int] = None,
|
||||
use_sde: bool = False,
|
||||
sde_sample_freq: int = -1,
|
||||
use_sde_at_warmup: bool = False,
|
||||
sde_support: bool = True):
|
||||
|
||||
super(OffPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=policy_base,
|
||||
learning_rate=learning_rate, policy_kwargs=policy_kwargs,
|
||||
tensorboard_log=tensorboard_log, verbose=verbose,
|
||||
device=device, support_multi_env=support_multi_env,
|
||||
create_eval_env=create_eval_env, monitor_wrapper=monitor_wrapper,
|
||||
seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq)
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.learning_starts = learning_starts
|
||||
self.actor = None # type: Optional[th.nn.Module]
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
# Update policy keyword arguments
|
||||
if sde_support:
|
||||
self.policy_kwargs['use_sde'] = self.use_sde
|
||||
self.policy_kwargs['device'] = self.device
|
||||
# For gSDE only
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _setup_model(self):
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
|
||||
self.action_space, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
def save_replay_buffer(self, path: str):
|
||||
"""
|
||||
Save the replay buffer as a pickle file.
|
||||
|
||||
:param path: (str) Path to a log folder
|
||||
"""
|
||||
assert self.replay_buffer is not None, "The replay buffer is not defined"
|
||||
with open(os.path.join(path, 'replay_buffer.pkl'), 'wb') as file_handler:
|
||||
pickle.dump(self.replay_buffer, file_handler)
|
||||
|
||||
def load_replay_buffer(self, path: str):
|
||||
"""
|
||||
Load a replay buffer from a pickle file.
|
||||
|
||||
:param path: (str) Path to the pickled replay buffer.
|
||||
"""
|
||||
with open(path, 'rb') as file_handler:
|
||||
self.replay_buffer = pickle.load(file_handler)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
|
||||
|
||||
def collect_rollouts(self,  # noqa: C901
                     env: VecEnv,
                     # Type hint as string to avoid circular import
                     callback: 'BaseCallback',
                     n_episodes: int = 1,
                     n_steps: int = -1,
                     action_noise: Optional[ActionNoise] = None,
                     learning_starts: int = 0,
                     replay_buffer: Optional[ReplayBuffer] = None,
                     log_interval: Optional[int] = None) -> RolloutReturn:
    """
    Collect experiences and store them into a ReplayBuffer.

    The environment is stepped until both ``total_steps >= n_steps``
    and ``total_episodes >= n_episodes`` hold, or until the callback
    asks to stop by returning ``False``.

    :param env: (VecEnv) The training environment
    :param callback: (BaseCallback) Callback that will be called at each step
        (and at the beginning and end of the rollout)
    :param n_episodes: (int) Number of episodes to use to collect rollout data
        You can also specify a ``n_steps`` instead
    :param n_steps: (int) Number of steps to use to collect rollout data
        You can also specify a ``n_episodes`` instead.
    :param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
        Required for deterministic policy (e.g. TD3). This can also be used
        in addition to the stochastic policy for SAC.
    :param learning_starts: (int) Number of steps before learning for the warm-up phase.
    :param replay_buffer: (Optional[ReplayBuffer]) Buffer where the transitions are stored,
        nothing is stored when it is ``None``
    :param log_interval: (int) Log data every ``log_interval`` episodes
    :return: (RolloutReturn)
    """
    episode_rewards, total_timesteps = [], []
    total_steps, total_episodes = 0, 0

    assert isinstance(env, VecEnv), "You must pass a VecEnv"
    assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"

    if n_episodes > 0 and n_steps > 0:
        # Note we are referring to the constructor arguments
        # that are named `train_freq` and `n_episodes_rollout`
        # but correspond to `n_steps` and `n_episodes` here
        warnings.warn("You passed a positive value for `train_freq` and `n_episodes_rollout`. "
                      "Please make sure this is intended. "
                      "The agent will collect data by stepping in the environment "
                      "until both conditions are true: "
                      "`number of steps in the env` >= `train_freq` and "
                      "`number of episodes` >= `n_episodes_rollout`")

    if self.use_sde:
        self.actor.reset_noise()

    callback.on_rollout_start()
    continue_training = True

    while total_steps < n_steps or total_episodes < n_episodes:
        done = False
        episode_reward, episode_timesteps = 0.0, 0

        while not done:

            if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise()

            # Select action randomly or according to policy
            if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
                # Warmup phase
                unscaled_action = np.array([self.action_space.sample()])
            else:
                # Note: we assume that the policy uses tanh to scale the action
                # We use non-deterministic action in the case of SAC, for TD3, it does not matter
                unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

            # Rescale the action from [low, high] to [-1, 1]
            if isinstance(self.action_space, gym.spaces.Box):
                scaled_action = self.policy.scale_action(unscaled_action)

                # Add noise to the action (improve exploration)
                if action_noise is not None:
                    # NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
                    # Update(October 2019): Not anymore
                    scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

                # We store the scaled action in the buffer
                buffer_action = scaled_action
                action = self.policy.unscale_action(scaled_action)
            else:
                # Discrete case, no need to normalize or clip
                buffer_action = unscaled_action
                action = buffer_action

            # Rescale and perform action
            new_obs, reward, done, infos = env.step(action)

            # Only stop training if return value is False, not when it is None.
            if callback.on_step() is False:
                return RolloutReturn(0.0, total_steps, total_episodes, continue_training=False)

            episode_reward += reward

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, done)

            # Store data in replay buffer
            if replay_buffer is not None:
                # Store only the unnormalized version
                if self._vec_normalize_env is not None:
                    new_obs_ = self._vec_normalize_env.get_original_obs()
                    reward_ = self._vec_normalize_env.get_original_reward()
                else:
                    # Avoid changing the original ones
                    self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

                replay_buffer.add(self._last_original_obs, new_obs_, buffer_action, reward_, done)

            self._last_obs = new_obs
            # Save the unnormalized observation
            # NOTE(review): `new_obs_` is only assigned when `replay_buffer` is not None,
            # confirm this method is never called with a VecNormalize env and no buffer.
            if self._vec_normalize_env is not None:
                self._last_original_obs = new_obs_

            self.num_timesteps += 1
            episode_timesteps += 1
            total_steps += 1
            # Stop in the middle of an episode when the step budget is reached
            if 0 < n_steps <= total_steps:
                break

        if done:
            total_episodes += 1
            self._episode_num += 1
            episode_rewards.append(episode_reward)
            total_timesteps.append(episode_timesteps)

            if action_noise is not None:
                action_noise.reset()

            # Log training infos
            if log_interval is not None and self._episode_num % log_interval == 0:
                # Guard against a zero elapsed time (coarse clocks) to avoid a ZeroDivisionError
                fps = int(self.num_timesteps / max(time.time() - self.start_time, 1e-8))
                logger.record("time/episodes", self._episode_num, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    logger.record('rollout/ep_rew_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
                    logger.record('rollout/ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
                logger.record("time/fps", fps)
                logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
                logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
                if self.use_sde:
                    logger.record("train/std", (self.actor.get_std()).mean().item())

                if len(self.ep_success_buffer) > 0:
                    logger.record('rollout/success rate', safe_mean(self.ep_success_buffer))
                # Pass the number of timesteps for tensorboard
                logger.dump(step=self.num_timesteps)

    mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0

    callback.on_rollout_end()

    return RolloutReturn(mean_reward, total_steps, total_episodes, continue_training)
|
||||
228
stable_baselines3/common/on_policy_algorithm.py
Normal file
228
stable_baselines3/common/on_policy_algorithm.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
import time
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.base_class import BaseAlgorithm
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
|
||||
|
||||
class OnPolicyAlgorithm(BaseAlgorithm):
    """
    The base for On-Policy algorithms (ex: A2C/PPO).

    :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: (float or callable) The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: (int) The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
    :param gamma: (float) Discount factor
    :param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: (float) Entropy coefficient for the loss calculation
    :param vf_coef: (float) Value function coefficient for the loss calculation
    :param max_grad_norm: (float) The maximum value for the gradient clipping
    :param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param create_eval_env: (bool) Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param monitor_wrapper: (bool) When creating an environment, whether to wrap it
        or not in a Monitor wrapper.
    :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
    :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: (int) Seed for the pseudo random generators
    :param device: (str or th.device) Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
    """

    def __init__(self,
                 policy: Union[str, Type[ActorCriticPolicy]],
                 env: Union[GymEnv, str],
                 learning_rate: Union[float, Callable],
                 n_steps: int,
                 gamma: float,
                 gae_lambda: float,
                 ent_coef: float,
                 vf_coef: float,
                 max_grad_norm: float,
                 use_sde: bool,
                 sde_sample_freq: int,
                 tensorboard_log: Optional[str] = None,
                 create_eval_env: bool = False,
                 monitor_wrapper: bool = True,
                 policy_kwargs: Optional[Dict[str, Any]] = None,
                 verbose: int = 0,
                 seed: Optional[int] = None,
                 device: Union[th.device, str] = 'auto',
                 _init_setup_model: bool = True):

        # `monitor_wrapper` must be forwarded, otherwise the user choice is silently ignored
        super(OnPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=ActorCriticPolicy,
                                                learning_rate=learning_rate, policy_kwargs=policy_kwargs,
                                                verbose=verbose, device=device, use_sde=use_sde,
                                                sde_sample_freq=sde_sample_freq, create_eval_env=create_eval_env,
                                                monitor_wrapper=monitor_wrapper,
                                                support_multi_env=True, seed=seed, tensorboard_log=tensorboard_log)

        self.n_steps = n_steps
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        # Created in `_setup_model()`
        self.rollout_buffer = None

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Create the learning rate schedule, seed the random generators
        and instantiate the rollout buffer and the policy.
        """
        self._setup_lr_schedule()
        self.set_random_seed(self.seed)

        self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space,
                                            self.action_space, self.device,
                                            gamma=self.gamma, gae_lambda=self.gae_lambda,
                                            n_envs=self.n_envs)
        self.policy = self.policy_class(self.observation_space, self.action_space,
                                        self.lr_schedule, use_sde=self.use_sde, device=self.device,
                                        **self.policy_kwargs)
        self.policy = self.policy.to(self.device)

    def collect_rollouts(self,
                         env: VecEnv,
                         callback: BaseCallback,
                         rollout_buffer: RolloutBuffer,
                         n_rollout_steps: int) -> bool:
        """
        Collect rollouts using the current policy and fill a `RolloutBuffer`.

        :param env: (VecEnv) The training environment
        :param callback: (BaseCallback) Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: (RolloutBuffer) Buffer to fill with rollouts
        :param n_rollout_steps: (int) Number of experiences to collect per environment
        :return: (bool) True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor
                obs_tensor = th.as_tensor(self._last_obs).to(self.device)
                actions, values, log_probs = self.policy.forward(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            # Only stop when the callback explicitly returns False (None means continue)
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1
            # One step is taken in every environment copy
            self.num_timesteps += env.num_envs

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)
            # The unclipped actions are the ones stored in the buffer
            rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
            self._last_obs = new_obs

        # NOTE(review): `values` was computed from the observation that preceded the last
        # step, not from `new_obs`; confirm this is the bootstrap value the buffer expects.
        rollout_buffer.compute_returns_and_advantage(values, dones=dones)

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Consume current rollout data and update policy parameters.
        Implemented by individual algorithms.
        """
        raise NotImplementedError

    def learn(self,
              total_timesteps: int,
              callback: MaybeCallback = None,
              log_interval: int = 1,
              eval_env: Optional[GymEnv] = None,
              eval_freq: int = -1,
              n_eval_episodes: int = 5,
              tb_log_name: str = "OnPolicyAlgorithm",
              eval_log_path: Optional[str] = None,
              reset_num_timesteps: bool = True) -> 'OnPolicyAlgorithm':
        """
        Alternate between collecting rollouts and calling ``train()``
        until ``total_timesteps`` is reached or the callback stops the training.

        :param total_timesteps: (int) Training budget, compared against ``self.num_timesteps``
        :param callback: (MaybeCallback) Callback(s) forwarded to ``_setup_learn()``
        :param log_interval: (int) Log data every ``log_interval`` iterations
            (no logging when ``None``)
        :param eval_env: (Optional[GymEnv]) Forwarded to ``_setup_learn()``
        :param eval_freq: (int) Forwarded to ``_setup_learn()``
        :param n_eval_episodes: (int) Forwarded to ``_setup_learn()``
        :param tb_log_name: (str) Forwarded to ``_setup_learn()``
        :param eval_log_path: (Optional[str]) Forwarded to ``_setup_learn()``
        :param reset_num_timesteps: (bool) Forwarded to ``_setup_learn()``
        :return: (OnPolicyAlgorithm) ``self``, to allow chaining
        """
        iteration = 0

        total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
                                                      n_eval_episodes, eval_log_path, reset_num_timesteps,
                                                      tb_log_name)

        callback.on_training_start(locals(), globals())

        while self.num_timesteps < total_timesteps:

            continue_training = self.collect_rollouts(self.env, callback,
                                                      self.rollout_buffer,
                                                      n_rollout_steps=self.n_steps)

            if continue_training is False:
                break

            iteration += 1
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            # Display training infos
            if log_interval is not None and iteration % log_interval == 0:
                # Guard against a zero elapsed time (coarse clocks) to avoid a ZeroDivisionError
                fps = int(self.num_timesteps / max(time.time() - self.start_time, 1e-8))
                logger.record("time/iterations", iteration, exclude="tensorboard")
                if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
                    logger.record("rollout/ep_rew_mean",
                                  safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
                    logger.record("rollout/ep_len_mean",
                                  safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
                logger.record("time/fps", fps)
                logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
                logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                logger.dump(step=self.num_timesteps)

            self.train()

        callback.on_training_end()

        return self

    def get_torch_variables(self) -> Tuple[List[str], List[str]]:
        """
        cf base class

        :return: (Tuple[List[str], List[str]]) The names of the attributes holding a state dict
            (the policy and its optimizer) and an empty list.
        """
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []
|
||||
|
|
@ -1,89 +1,20 @@
|
|||
from typing import Union, Type, Dict, List, Tuple, Optional, Any
|
||||
|
||||
from itertools import zip_longest
|
||||
from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
|
||||
from functools import partial
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space
|
||||
from stable_baselines3.common.utils import get_device
|
||||
from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space
|
||||
from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
|
||||
NatureCNN, MlpExtractor)
|
||||
from stable_baselines3.common.utils import get_device, is_vectorized_observation
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage
|
||||
|
||||
|
||||
class BaseFeaturesExtractor(nn.Module):
    """
    Common parent of all features extractors.

    Subclasses must implement ``forward()`` and report the size
    of the features they output through ``features_dim``.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted, must be strictly positive.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        # The default value (0) is rejected: subclasses have to give the real dimension
        assert features_dim > 0
        self._features_dim = features_dim
        self._observation_space = observation_space

    @property
    def features_dim(self) -> int:
        """(int) Number of features extracted (read-only)."""
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Extract the features, to be implemented by the subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Features extractor that only flattens its input.
    It serves as a placeholder when no feature extraction is needed.

    :param observation_space: (gym.Space)
    """

    def __init__(self, observation_space: gym.Space):
        # The number of features is the size of the flattened observation
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return the observations passed through ``nn.Flatten``."""
        flattened = self.flatten(observations)
        return flattened
|
||||
|
||||
|
||||
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN nature paper:
    Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.

    Three convolutional layers (32, 64 and 64 filters) followed
    by a fully-connected layer of ``features_dim`` units.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box,
                 features_dim: int = 512):
        super(NatureCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space), ('You should use NatureCNN '
                                                   f'only with images not with {observation_space} '
                                                   '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
        n_input_channels = observation_space.shape[0]
        # The third convolution has 64 filters in the Nature architecture (not 32).
        # NOTE: checkpoints saved with the previous 32-filter layer are not compatible.
        self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
                                 nn.ReLU(),
                                 nn.Flatten())

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Apply the convolutional stack then the final linear layer."""
        return self.linear(self.cnn(observations))
|
||||
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
|
||||
DiagGaussianDistribution, CategoricalDistribution,
|
||||
MultiCategoricalDistribution, BernoulliDistribution,
|
||||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class BasePolicy(nn.Module):
|
||||
|
|
@ -93,8 +24,6 @@ class BasePolicy(nn.Module):
|
|||
:param observation_space: (gym.spaces.Space) The observation space of the environment
|
||||
:param action_space: (gym.spaces.Space) The action space of the environment
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
:param squash_output: (bool) For continuous actions, whether the output is squashed
|
||||
or not using a ``tanh()`` function.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
|
|
@ -106,9 +35,12 @@ class BasePolicy(nn.Module):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param squash_output: (bool) For continuous actions, whether the output is squashed
|
||||
or not using a ``tanh()`` function.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
|
|
@ -166,7 +98,7 @@ class BasePolicy(nn.Module):
|
|||
module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(_progress: float) -> float:
|
||||
def _dummy_schedule(_progress_remaining: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
return 0.0
|
||||
|
||||
|
|
@ -183,12 +115,14 @@ class BasePolicy(nn.Module):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self, observation: np.ndarray,
|
||||
def predict(self,
|
||||
observation: np.ndarray,
|
||||
state: Optional[np.ndarray] = None,
|
||||
mask: Optional[np.ndarray] = None,
|
||||
deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
|
||||
"""
|
||||
Get the policy action and state from an observation (and optional state).
|
||||
Includes sugar-coating to handle different observations (e.g. normalizing images).
|
||||
|
||||
:param observation: (np.ndarray) the input observation
|
||||
:param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
|
||||
|
|
@ -216,7 +150,7 @@ class BasePolicy(nn.Module):
|
|||
or transpose_obs.shape[1:] == self.observation_space.shape):
|
||||
observation = transpose_obs
|
||||
|
||||
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
|
||||
vectorized_env = is_vectorized_observation(observation, self.observation_space)
|
||||
|
||||
observation = observation.reshape((-1,) + self.observation_space.shape)
|
||||
|
||||
|
|
@ -263,57 +197,6 @@ class BasePolicy(nn.Module):
|
|||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
@staticmethod
def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
    """
    Tell whether an observation is a batch of observations or a single one,
    validating its shape against the observation space on the way.

    :param observation: (np.ndarray) the input observation to validate
    :param observation_space: (gym.spaces) the observation space
    :return: (bool) whether the given observation is vectorized or not
    :raises ValueError: if the shape does not match the space
        or if the space type is not handled
    """
    # Every branch below either returns or raises, so the checks can be written as guards
    if isinstance(observation_space, gym.spaces.Box):
        if observation.shape == observation_space.shape:
            return False
        if observation.shape[1:] == observation_space.shape:
            return True
        expected_dims = ", ".join(map(str, observation_space.shape))
        raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
                         + f"Box environment, please use {observation_space.shape} "
                         + "or (n_env, {}) for the observation shape.".format(expected_dims))

    if isinstance(observation_space, gym.spaces.Discrete):
        # A numpy array of a number, has shape empty tuple '()'
        if observation.shape == ():
            return False
        if len(observation.shape) == 1:
            return True
        raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
                         + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")

    if isinstance(observation_space, gym.spaces.MultiDiscrete):
        if observation.shape == (len(observation_space.nvec),):
            return False
        if len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
            return True
        n_components = len(observation_space.nvec)
        raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
                         + f"environment, please use ({n_components},) or "
                         + f"(n_env, {n_components}) for the observation shape.")

    if isinstance(observation_space, gym.spaces.MultiBinary):
        if observation.shape == (observation_space.n,):
            return False
        if len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
            return True
        raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
                         + f"environment, please use ({observation_space.n},) or "
                         + f"(n_env, {observation_space.n}) for the observation shape.")

    # Unknown space type
    raise ValueError("Error: Cannot determine if the observation is vectorized "
                     + f" with the space type {observation_space}.")
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get data that need to be saved in order to re-create the policy.
|
||||
|
|
@ -373,42 +256,365 @@ class BasePolicy(nn.Module):
|
|||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
|
||||
def create_mlp(input_dim: int,
|
||||
output_dim: int,
|
||||
net_arch: List[int],
|
||||
activation_fn: Type[nn.Module] = nn.ReLU,
|
||||
squash_output: bool = False) -> List[nn.Module]:
|
||||
class ActorCriticPolicy(BasePolicy):
|
||||
"""
|
||||
Create a multi layer perceptron (MLP), which is
|
||||
a collection of fully-connected layers each followed by an activation function.
|
||||
Policy class for actor-critic algorithms (has both policy and value prediction).
|
||||
Used by A2C, PPO and the likes.
|
||||
|
||||
:param input_dim: (int) Dimension of the input vector
|
||||
:param output_dim: (int)
|
||||
:param net_arch: (List[int]) Architecture of the neural net
|
||||
It represents the number of units per layer.
|
||||
The length of this list is the number of layers.
|
||||
:param activation_fn: (Type[nn.Module]) The activation function
|
||||
to use after each layer.
|
||||
:param squash_output: (bool) Whether to squash the output using a Tanh
|
||||
activation function
|
||||
:return: (List[nn.Module])
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param lr_schedule: (Callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (Type[nn.Module]) Activation function
|
||||
:param ortho_init: (bool) Whether to use or not orthogonal initialization
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
|
||||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
if len(net_arch) > 0:
|
||||
modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()]
|
||||
else:
|
||||
modules = []
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
ortho_init: bool = True,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = 0.0,
|
||||
full_std: bool = True,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False,
|
||||
squash_output: bool = False,
|
||||
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
|
||||
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
|
||||
for idx in range(len(net_arch) - 1):
|
||||
modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1]))
|
||||
modules.append(activation_fn())
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
# Small values to avoid NaN in ADAM optimizer
|
||||
if optimizer_class == th.optim.Adam:
|
||||
optimizer_kwargs['eps'] = 1e-5
|
||||
|
||||
if output_dim > 0:
|
||||
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
|
||||
modules.append(nn.Linear(last_layer_dim, output_dim))
|
||||
if squash_output:
|
||||
modules.append(nn.Tanh())
|
||||
return modules
|
||||
super(ActorCriticPolicy, self).__init__(observation_space,
|
||||
action_space,
|
||||
device,
|
||||
features_extractor_class,
|
||||
features_extractor_kwargs,
|
||||
optimizer_class=optimizer_class,
|
||||
optimizer_kwargs=optimizer_kwargs,
|
||||
squash_output=squash_output)
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
if net_arch is None:
|
||||
if features_extractor_class == FlattenExtractor:
|
||||
net_arch = [dict(pi=[64, 64], vf=[64, 64])]
|
||||
else:
|
||||
net_arch = []
|
||||
|
||||
self.net_arch = net_arch
|
||||
self.activation_fn = activation_fn
|
||||
self.ortho_init = ortho_init
|
||||
|
||||
self.features_extractor = features_extractor_class(self.observation_space,
|
||||
**self.features_extractor_kwargs)
|
||||
self.features_dim = self.features_extractor.features_dim
|
||||
|
||||
self.normalize_images = normalize_images
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
# Keyword arguments for gSDE distribution
|
||||
if use_sde:
|
||||
dist_kwargs = {
|
||||
'full_std': full_std,
|
||||
'squash_output': squash_output,
|
||||
'use_expln': use_expln,
|
||||
'learn_features': sde_net_arch is not None
|
||||
}
|
||||
|
||||
self.sde_features_extractor = None
|
||||
self.sde_net_arch = sde_net_arch
|
||||
self.use_sde = use_sde
|
||||
self.dist_kwargs = dist_kwargs
|
||||
|
||||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
||||
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
    """
    Collect the constructor arguments needed to re-create this policy.

    The dictionary returned by the parent class is extended with the
    actor-critic specific settings.

    :return: (Dict[str, Any]) Mapping from constructor argument name to its value.
    """
    data = super()._get_data()

    data.update(dict(
        net_arch=self.net_arch,
        activation_fn=self.activation_fn,
        use_sde=self.use_sde,
        log_std_init=self.log_std_init,
        # ``dist_kwargs`` is only set when gSDE is used, hence the ``None`` fallbacks
        squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
        full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
        # ``sde_net_arch`` is not a key of ``dist_kwargs`` (looking it up there
        # raised a KeyError as soon as gSDE was enabled): read the attribute instead
        sde_net_arch=self.sde_net_arch,
        use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
        lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
        ortho_init=self.ortho_init,
        optimizer_class=self.optimizer_class,
        optimizer_kwargs=self.optimizer_kwargs,
        features_extractor_class=self.features_extractor_class,
        features_extractor_kwargs=self.features_extractor_kwargs
    ))
    return data
|
||||
|
||||
def reset_noise(self, n_envs: int = 1) -> None:
    """
    Draw a fresh set of weights for the exploration matrix.

    :param n_envs: (int) Forwarded as ``batch_size`` when sampling the weights.
    """
    uses_gsde = isinstance(self.action_dist, StateDependentNoiseDistribution)
    assert uses_gsde, 'reset_noise() is only available when using gSDE'
    self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
    """
    Instantiate the latent extractor, the action/value heads and the optimizer.

    :param lr_schedule: (Callable) Learning rate schedule,
        it is called with ``1`` to obtain the initial learning rate.
    """
    # Note: when ``net_arch`` is an empty list, the mlp extractor
    # does not contain any layer.
    self.mlp_extractor = MlpExtractor(self.features_dim,
                                      net_arch=self.net_arch,
                                      activation_fn=self.activation_fn,
                                      device=self.device)
    pi_dim = self.mlp_extractor.latent_dim_pi

    # Optional dedicated features extractor for gSDE
    has_sde_extractor = self.sde_net_arch is not None
    if has_sde_extractor:
        sde_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim,
                                                                      self.sde_net_arch,
                                                                      self.activation_fn)
        self.sde_features_extractor = sde_extractor

    dist = self.action_dist
    if isinstance(dist, DiagGaussianDistribution):
        self.action_net, self.log_std = dist.proba_distribution_net(latent_dim=pi_dim,
                                                                    log_std_init=self.log_std_init)
    elif isinstance(dist, StateDependentNoiseDistribution):
        # Without a dedicated extractor, gSDE uses the policy latent features
        if not has_sde_extractor:
            latent_sde_dim = pi_dim
        self.action_net, self.log_std = dist.proba_distribution_net(latent_dim=pi_dim,
                                                                    latent_sde_dim=latent_sde_dim,
                                                                    log_std_init=self.log_std_init)
    elif isinstance(dist, (CategoricalDistribution,
                           MultiCategoricalDistribution,
                           BernoulliDistribution)):
        # These distributions only need a network producing the logits
        self.action_net = dist.proba_distribution_net(latent_dim=pi_dim)

    self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)

    if self.ortho_init:
        # Orthogonal initialization, with a small gain for the action output.
        # TODO: check for features_extractor
        # Gains come from stable-baselines; the feature_extractor/mlp ones
        # are originally from openai/baselines (default gains/init_scales).
        hidden_gain = np.sqrt(2)
        module_gains = {
            self.features_extractor: hidden_gain,
            self.mlp_extractor: hidden_gain,
            self.action_net: 0.01,
            self.value_net: 1
        }
        for module, gain in module_gains.items():
            module.apply(partial(self.init_weights, gain=gain))

    # The optimizer starts with the initial learning rate
    initial_lr = lr_schedule(1)
    self.optimizer = self.optimizer_class(self.parameters(), lr=initial_lr, **self.optimizer_kwargs)
|
||||
|
||||
def forward(self, obs: th.Tensor,
            deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Run the actor and the critic on the given observation.

    :param obs: (th.Tensor) Observation
    :param deterministic: (bool) Whether to sample or use deterministic actions
    :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) action, value and log probability of the action
    """
    pi_latent, vf_latent, sde_latent = self._get_latent(obs)
    # Critic: value estimate for the given observation
    value_estimates = self.value_net(vf_latent)
    # Actor: build the distribution, pick the action and score it
    action_dist = self._get_action_dist_from_latent(pi_latent, latent_sde=sde_latent)
    chosen_actions = action_dist.get_actions(deterministic=deterministic)
    return chosen_actions, value_estimates, action_dist.log_prob(chosen_actions)
|
||||
|
||||
def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Compute the latent codes fed to the actor, the critic and gSDE.

    :param obs: (th.Tensor) Observation
    :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes
        for the actor, the value function and for gSDE function
    """
    extracted = self.extract_features(obs)
    latent_pi, latent_vf = self.mlp_extractor(extracted)

    # gSDE re-uses the actor latent unless it has its own features extractor
    if self.sde_features_extractor is None:
        latent_sde = latent_pi
    else:
        latent_sde = self.sde_features_extractor(extracted)
    return latent_pi, latent_vf, latent_sde
|
||||
|
||||
def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
                                 latent_sde: Optional[th.Tensor] = None) -> Distribution:
    """
    Build the action distribution from the latent codes.

    :param latent_pi: (th.Tensor) Latent code for the actor
    :param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function
    :return: (Distribution) Action distribution
    """
    net_output = self.action_net(latent_pi)
    dist = self.action_dist

    if isinstance(dist, DiagGaussianDistribution):
        return dist.proba_distribution(net_output, self.log_std)
    if isinstance(dist, (CategoricalDistribution,
                         MultiCategoricalDistribution,
                         BernoulliDistribution)):
        # For these distributions the network output is passed as logits
        # (flattened for the multi-categorical case)
        return dist.proba_distribution(action_logits=net_output)
    if isinstance(dist, StateDependentNoiseDistribution):
        return dist.proba_distribution(net_output, self.log_std, latent_sde)
    raise ValueError('Invalid action distribution')
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
    """
    Select an action for the given observation with the current policy.

    :param observation: (th.Tensor)
    :param deterministic: (bool) Whether to use stochastic or deterministic actions
    :return: (th.Tensor) Taken action according to the policy
    """
    # The value latent is not needed to pick an action
    pi_latent, _, sde_latent = self._get_latent(observation)
    action_dist = self._get_action_dist_from_latent(pi_latent, sde_latent)
    return action_dist.get_actions(deterministic=deterministic)
|
||||
|
||||
def evaluate_actions(self, obs: th.Tensor,
                     actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
    """
    Score the given actions under the current policy for the given observations.

    :param obs: (th.Tensor)
    :param actions: (th.Tensor)
    :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions
        and entropy of the action distribution.
    """
    pi_latent, vf_latent, sde_latent = self._get_latent(obs)
    action_dist = self._get_action_dist_from_latent(pi_latent, sde_latent)
    action_log_prob = action_dist.log_prob(actions)
    value_estimates = self.value_net(vf_latent)
    return value_estimates, action_log_prob, action_dist.entropy()
|
||||
|
||||
|
||||
class ActorCriticCnnPolicy(ActorCriticPolicy):
    """
    Actor-critic policy (policy and value prediction) with a CNN features extractor.
    Used by A2C, PPO and the likes.

    Every argument is forwarded unchanged to ``ActorCriticPolicy``;
    the features extractor defaults to ``NatureCNN`` here.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (Callable) Learning rate schedule (could be constant)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
    :param device: (str or th.device) Device on which the code should run.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param ortho_init: (bool) Whether to use or not orthogonal initialization
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param full_std: (bool) Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: (bool) Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
        dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 lr_schedule: Callable,
                 net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
                 device: Union[th.device, str] = 'auto',
                 activation_fn: Type[nn.Module] = nn.Tanh,
                 ortho_init: bool = True,
                 use_sde: bool = False,
                 log_std_init: float = 0.0,
                 full_std: bool = True,
                 sde_net_arch: Optional[List[int]] = None,
                 use_expln: bool = False,
                 squash_output: bool = False,
                 features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                 normalize_images: bool = True,
                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                 optimizer_kwargs: Optional[Dict[str, Any]] = None):
        # Arguments are passed positionally, in the order expected by the parent class
        super().__init__(
            observation_space, action_space, lr_schedule, net_arch, device,
            activation_fn, ortho_init, use_sde, log_std_init, full_std,
            sde_net_arch, use_expln, squash_output,
            features_extractor_class, features_extractor_kwargs,
            normalize_images, optimizer_class, optimizer_kwargs,
        )
|
||||
|
||||
|
||||
def create_sde_features_extractor(features_dim: int,
|
||||
|
|
@ -437,7 +643,8 @@ _policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePol
|
|||
|
||||
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
|
||||
"""
|
||||
Returns the registered policy from the base type and name
|
||||
Returns the registered policy from the base type and name.
|
||||
See `register_policy` for registering policies and explanation.
|
||||
|
||||
:param base_policy_type: (Type[BasePolicy]) the base policy class
|
||||
:param name: (str) the policy name
|
||||
|
|
@ -454,7 +661,23 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
|
|||
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
|
||||
"""
|
||||
Register a policy, so it can be called using its name.
|
||||
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...)
|
||||
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...).
|
||||
|
||||
The goal here is to standardize policy naming, e.g.
|
||||
all algorithms can call upon "MlpPolicy" or "CnnPolicy",
|
||||
and they receive respective policies that work for them.
|
||||
Consider following:
|
||||
|
||||
OnlinePolicy
|
||||
-- OnlineMlpPolicy ("MlpPolicy")
|
||||
-- OnlineCnnPolicy ("CnnPolicy")
|
||||
OfflinePolicy
|
||||
-- OfflineMlpPolicy ("MlpPolicy")
|
||||
-- OfflineCnnPolicy ("CnnPolicy")
|
||||
|
||||
Two policies have name "MlpPolicy" and two have "CnnPolicy".
|
||||
In `get_policy_from_name`, the parent class (e.g. OnlinePolicy)
|
||||
is given and used to select and return the correct policy.
|
||||
|
||||
:param name: (str) the policy name
|
||||
:param policy: (Type[BasePolicy]) the policy class
|
||||
|
|
@ -470,98 +693,9 @@ def register_policy(name: str, policy: Type[BasePolicy]) -> None:
|
|||
if sub_class not in _policy_registry:
|
||||
_policy_registry[sub_class] = {}
|
||||
if name in _policy_registry[sub_class]:
|
||||
raise ValueError(f"Error: the name {name} is alreay registered for a different policy, will not override.")
|
||||
# Check if the registered policy is same
|
||||
# we try to register. If not so,
|
||||
# do not override and complain.
|
||||
if _policy_registry[sub_class][name] != policy:
|
||||
raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.")
|
||||
_policy_registry[sub_class][name] = policy
|
||||
|
||||
|
||||
class MlpExtractor(nn.Module):
    """
    MLP that maps a feature vector to two latent representations:
    one for the policy and one for the value network.

    ``net_arch`` describes the hidden layers and how many of them are shared
    between the two networks. It is a list made of:

    1. Any number (zero allowed) of integers, each one being the number of units
       of a shared layer. With no integer, there is no shared layer.
    2. An optional dict describing the non-shared layers that follow, formatted like
       ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       A missing key (pi or vf) means no non-shared layer for that network.
       Parsing stops at the first dict.

    For example ``[55, dict(vf=[255, 255], pi=[128])]`` creates one shared layer of size 55,
    followed by two layers of size 255 for the value network and one layer of size 128
    for the policy network. ``[128, 128]`` creates two shared layers of size 128.

    Adapted from Stable Baselines.

    :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: (Type[nn.Module]) The activation function to use for the networks.
    :param device: (th.device)
    """

    def __init__(self, feature_dim: int,
                 net_arch: List[Union[int, Dict[str, List[int]]]],
                 activation_fn: Type[nn.Module],
                 device: Union[th.device, str] = 'auto'):
        super().__init__()
        device = get_device(device)

        shared_layers, pi_layers, vf_layers = [], [], []
        pi_sizes = []  # Sizes of the layers that belong to the policy network only
        vf_sizes = []  # Sizes of the layers that belong to the value network only
        shared_dim = feature_dim

        # Shared part: consume integers until the first dict is found
        for spec in net_arch:
            if not isinstance(spec, int):
                assert isinstance(spec, dict), "Error: the net_arch list can only contain ints and dicts"
                if 'pi' in spec:
                    assert isinstance(spec['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    pi_sizes = spec['pi']

                if 'vf' in spec:
                    assert isinstance(spec['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    vf_sizes = spec['vf']
                # From here on the network splits up in policy and value network
                break
            # TODO: give layer a meaningful name
            shared_layers.append(nn.Linear(shared_dim, spec))
            shared_layers.append(activation_fn())
            shared_dim = spec

        pi_dim = vf_dim = shared_dim

        # Non-shared part: policy and value layers are created alternately
        for pi_size, vf_size in zip_longest(pi_sizes, vf_sizes):
            if pi_size is not None:
                assert isinstance(pi_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                pi_layers.append(nn.Linear(pi_dim, pi_size))
                pi_layers.append(activation_fn())
                pi_dim = pi_size

            if vf_size is not None:
                assert isinstance(vf_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                vf_layers.append(nn.Linear(vf_dim, vf_size))
                vf_layers.append(activation_fn())
                vf_dim = vf_size

        # Output dimensions, used to create the distributions
        self.latent_dim_pi = pi_dim
        self.latent_dim_vf = vf_dim

        # An empty list of layers gives an empty ``nn.Sequential``,
        # which returns its input unchanged
        self.shared_net = nn.Sequential(*shared_layers).to(device)
        self.policy_net = nn.Sequential(*pi_layers).to(device)
        self.value_net = nn.Sequential(*vf_layers).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        shared = self.shared_net(features)
        latent_policy = self.policy_net(shared)
        latent_value = self.value_net(shared)
        return latent_policy, latent_value
|
||||
|
|
|
|||
|
|
@ -2,13 +2,20 @@
|
|||
Save util taken from stable_baselines
|
||||
used to serialize data (class parameters) of model classes
|
||||
"""
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import base64
|
||||
import functools
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import cloudpickle
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
import warnings
|
||||
import zipfile
|
||||
|
||||
import torch as th
|
||||
import cloudpickle
|
||||
|
||||
from stable_baselines3.common.type_aliases import TensorDict
|
||||
from stable_baselines3.common.utils import get_device
|
||||
|
||||
|
||||
def recursive_getattr(obj: Any, attr: str, *args) -> Any:
|
||||
|
|
@ -165,3 +172,115 @@ def json_to_data(json_string: str,
|
|||
# Read as it is
|
||||
return_data[data_key] = data_item
|
||||
return return_data
|
||||
|
||||
|
||||
def save_to_zip_file(save_path: str, data: Dict[str, Any] = None,
                     params: Dict[str, Any] = None, tensors: Dict[str, Any] = None) -> None:
    """
    Write a model to a zip archive.

    :param save_path: Where to store the model (a path or a file-like object).
        A ".zip" extension is appended to a path that has no extension.
    :param data: Class parameters being stored.
    :param params: Model parameters being stored expected to contain an entry for every
        state_dict with its name and the state_dict.
    :param tensors: Extra tensor variables expected to contain name and value of tensors
    """
    # ``data`` is optional: only serialize it when it was provided
    has_data = data is not None
    if has_data:
        serialized_data = data_to_json(data)

    # Only a string path can be checked for a missing extension
    if isinstance(save_path, str):
        extension = os.path.splitext(save_path)[1]
        if extension == "":
            save_path += ".zip"

    # ``zipfile.ZipFile`` accepts both a path and a file-like object
    with zipfile.ZipFile(save_path, "w") as archive:
        # Entries that are ``None`` are simply left out of the archive
        if has_data:
            archive.writestr("data", serialized_data)
        if tensors is not None:
            with archive.open('tensors.pth', mode="w") as tensors_file:
                th.save(tensors, tensors_file)
        if params is not None:
            for entry_name, state in params.items():
                with archive.open(entry_name + '.pth', mode="w") as param_file:
                    th.save(state, param_file)
|
||||
|
||||
|
||||
def load_from_zip_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[Dict[str, Any]],
                                                                         Optional[TensorDict],
                                                                         Optional[TensorDict]]):
    """
    Read model data back from a .zip archive.

    :param load_path: Where to load the model from
    :param load_data: Whether we should load and return data
        (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
    :return: (dict),(dict),(dict) Class parameters, model state_dicts (dict of state_dict)
        and dict of extra tensors
    """
    # For a string path, also try the same path with a ".zip" extension
    if isinstance(load_path, str) and not os.path.exists(load_path):
        if not os.path.exists(load_path + ".zip"):
            raise ValueError(f"Error: the file {load_path} could not be found")
        load_path += ".zip"

    # Device used as ``map_location`` when loading the saved objects
    device = get_device()

    def _load_member(archive, member_name):
        # The archive member has to be seekable for ``th.load``,
        # so it is copied into an in-memory buffer first
        with archive.open(member_name, mode="r") as member_file:
            buffer = io.BytesIO(member_file.read())
        return th.load(buffer, map_location=device)

    # Entries that are missing from the archive are assumed
    # to have been saved as ``None``.
    data = None
    tensors = None
    params = {}

    try:
        with zipfile.ZipFile(load_path, "r") as archive:
            namelist = archive.namelist()

            if load_data and "data" in namelist:
                # Class parameters are stored as a json string
                data = json_to_data(archive.read("data").decode())

            if load_data and "tensors.pth" in namelist:
                # Extra tensors
                tensors = _load_member(archive, 'tensors.pth')

            # Every other ".pth" file is loaded into ``params``,
            # keyed by its file name without the extension
            for member_name in namelist:
                stem, extension = os.path.splitext(member_name)
                if extension == ".pth" and member_name != "tensors.pth":
                    params[stem] = _load_member(archive, member_name)
    except zipfile.BadZipFile:
        # load_path wasn't a zip file
        raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
    return data, params, tensors
|
||||
|
|
|
|||
218
stable_baselines3/common/torch_layers.py
Normal file
218
stable_baselines3/common/torch_layers.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
from typing import Union, Type, Dict, List, Tuple
|
||||
|
||||
from itertools import zip_longest
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
|
||||
from stable_baselines3.common.utils import get_device
|
||||
|
||||
|
||||
class BaseFeaturesExtractor(nn.Module):
    """
    Common base for all the features extractors.

    Subclasses must implement ``forward()``.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        It must be strictly positive, the default value (0) is rejected.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 0):
        super().__init__()
        assert features_dim > 0
        self._observation_space = observation_space
        self._features_dim = features_dim

    @property
    def features_dim(self) -> int:
        # Read-only access to the number of extracted features
        return self._features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        raise NotImplementedError()
|
||||
|
||||
|
||||
class FlattenExtractor(BaseFeaturesExtractor):
    """
    Features extractor that only flattens its input.
    Used as a placeholder when feature extraction is not needed.

    :param observation_space: (gym.Space)
    """

    def __init__(self, observation_space: gym.Space):
        # The number of features is the flattened observation dimension
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.flatten(observations)
|
||||
|
||||
|
||||
class NatureCNN(BaseFeaturesExtractor):
    """
    CNN from DQN nature paper:
        Mnih, Volodymyr, et al.
        "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.

    :param observation_space: (gym.Space)
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    """

    def __init__(self, observation_space: gym.spaces.Box,
                 features_dim: int = 512):
        super(NatureCNN, self).__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space), ('You should use NatureCNN '
                                                   f'only with images not with {observation_space} '
                                                   '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
        n_input_channels = observation_space.shape[0]
        # Three convolutions with 32, 64 and 64 filters, as in the paper.
        # NOTE: the last convolution previously had 32 output channels, so
        # checkpoints saved with that layout have different parameter shapes.
        self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
                                 nn.ReLU(),
                                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
                                 nn.ReLU(),
                                 nn.Flatten())

        # Compute the flattened size by doing one forward pass on a sampled observation
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
|
||||
|
||||
|
||||
def create_mlp(input_dim: int,
               output_dim: int,
               net_arch: List[int],
               activation_fn: Type[nn.Module] = nn.ReLU,
               squash_output: bool = False) -> List[nn.Module]:
    """
    Build the list of modules of a multi layer perceptron (MLP):
    fully-connected hidden layers, each followed by an activation function,
    and an optional output layer.

    :param input_dim: (int) Dimension of the input vector
    :param output_dim: (int) Dimension of the output layer,
        no output layer is added when it is not strictly positive
    :param net_arch: (List[int]) Architecture of the neural net
        It represents the number of units per layer.
        The length of this list is the number of layers.
    :param activation_fn: (Type[nn.Module]) The activation function
        to use after each layer.
    :param squash_output: (bool) Whether to squash the output using a Tanh
        activation function
    :return: (List[nn.Module])
    """
    # Consecutive pairs of this list are the (in, out) sizes of the hidden layers
    layer_dims = [input_dim, *net_arch]

    modules = []
    for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
        modules.append(nn.Linear(in_dim, out_dim))
        modules.append(activation_fn())

    if output_dim > 0:
        # The output layer is plugged on the last hidden layer
        # (or directly on the input when there is none)
        modules.append(nn.Linear(layer_dims[-1], output_dim))
    if squash_output:
        modules.append(nn.Tanh())
    return modules
|
||||
|
||||
|
||||
class MlpExtractor(nn.Module):
    """
    Constructs an MLP that receives observations as an input and outputs a latent representation for the policy and
    a value network. The ``net_arch`` parameter allows to specify the amount and size of the hidden layers and how many
    of them are shared between the policy network and the value network. It is assumed to be a list with the following
    structure:

    1. An arbitrary length (zero allowed) number of integers each specifying the number of units in a shared layer.
       If the number of ints is zero, there will be no shared layers.
    2. An optional dict, to specify the following non-shared layers for the value network and the policy network.
       It is formatted like ``dict(vf=[<value layer sizes>], pi=[<policy layer sizes>])``.
       If it is missing any of the keys (pi or vf), no non-shared layers (empty list) is assumed.
       Parsing stops at the first dict: any entry placed after it is ignored.

    For example to construct a network with one shared layer of size 55 followed by two non-shared layers for the value
    network of size 255 and a single non-shared layer of size 128 for the policy network, the following ``net_arch``
    would be used: ``[55, dict(vf=[255, 255], pi=[128])]``. A simple shared network topology with two layers of size 128
    would be specified as ``[128, 128]``.

    Adapted from Stable Baselines.

    :param feature_dim: (int) Dimension of the feature vector (can be the output of a CNN)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
        See above for details on its formatting.
    :param activation_fn: (Type[nn.Module]) The activation function to use for the networks.
        It is instantiated (without arguments) once after every linear layer.
    :param device: (th.device or str) Device the three sub-networks are moved to,
        after being resolved through ``get_device``.
    """

    def __init__(self, feature_dim: int,
                 net_arch: List[Union[int, Dict[str, List[int]]]],
                 activation_fn: Type[nn.Module],
                 device: Union[th.device, str] = 'auto'):
        super(MlpExtractor, self).__init__()
        device = get_device(device)
        shared_net, policy_net, value_net = [], [], []
        policy_only_layers = []  # Layer sizes of the network that only belongs to the policy network
        value_only_layers = []  # Layer sizes of the network that only belongs to the value network
        last_layer_dim_shared = feature_dim

        # Iterate through the shared layers and build the shared parts of the network
        for layer in net_arch:
            if isinstance(layer, int):  # Check that this is a shared layer
                layer_size = layer
                # TODO: give layer a meaningful name
                shared_net.append(nn.Linear(last_layer_dim_shared, layer_size))
                shared_net.append(activation_fn())
                last_layer_dim_shared = layer_size
            else:
                assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
                if 'pi' in layer:
                    assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
                    policy_only_layers = layer['pi']

                if 'vf' in layer:
                    assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
                    value_only_layers = layer['vf']
                break  # From here on the network splits up in policy and value network

        last_layer_dim_pi = last_layer_dim_shared
        last_layer_dim_vf = last_layer_dim_shared

        # Build the non-shared part of the network.
        # ``zip_longest`` pads the shorter list with None, so both heads are built in one pass;
        # the interleaved creation order (pi layer, then vf layer) is kept on purpose.
        for pi_layer_size, vf_layer_size in zip_longest(policy_only_layers, value_only_layers):
            if pi_layer_size is not None:
                assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
                policy_net.append(nn.Linear(last_layer_dim_pi, pi_layer_size))
                policy_net.append(activation_fn())
                last_layer_dim_pi = pi_layer_size

            if vf_layer_size is not None:
                assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
                value_net.append(nn.Linear(last_layer_dim_vf, vf_layer_size))
                value_net.append(activation_fn())
                last_layer_dim_vf = vf_layer_size

        # Save dim, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Create networks
        # If the list of layers is empty, the network will just act as an Identity module
        self.shared_net = nn.Sequential(*shared_net).to(device)
        self.policy_net = nn.Sequential(*policy_net).to(device)
        self.value_net = nn.Sequential(*value_net).to(device)

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        Run the shared layers once, then feed their output to both heads.

        :param features: (th.Tensor) Input of the shared network
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        shared_latent = self.shared_net(features)
        return self.policy_net(shared_latent), self.value_net(shared_latent)
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
from collections import deque
|
||||
from typing import Callable, Union, Optional
|
||||
import random
|
||||
|
||||
import os
|
||||
import glob
|
||||
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch as th
|
||||
# Check if tensorboard is available for pytorch
|
||||
|
|
@ -12,6 +15,9 @@ except ImportError:
|
|||
SummaryWriter = None
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.type_aliases import GymEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space
|
||||
from stable_baselines3.common.vec_env import VecTransposeImage
|
||||
|
||||
|
||||
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
|
||||
|
|
@ -158,3 +164,86 @@ def configure_logger(verbose: int = 0, tensorboard_log: Optional[str] = None,
|
|||
logger.configure(save_path, ["tensorboard"])
|
||||
elif verbose == 0:
|
||||
logger.configure(format_strings=[""])
|
||||
|
||||
|
||||
def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, action_space: gym.spaces.Space):
    """
    Verify that ``env`` exposes the given observation and action spaces.
    Used by BaseAlgorithm to check if spaces match after loading the model with given env.

    An observation space mismatch is tolerated in one case: ``env`` has an image
    observation space and the provided space equals its transposed version
    (as computed by ``VecTransposeImage.transpose_space``).

    :param env: (GymEnv) Environment to check for valid spaces
    :param observation_space: (gym.spaces.Space) Observation space to check against
    :param action_space: (gym.spaces.Space) Action space to check against
    :raises ValueError: if the observation spaces or the action spaces do not match
    """
    env_observation_space = env.observation_space

    if observation_space != env_observation_space:
        # Special cases for images that need to be transposed
        matches_after_transpose = (
            is_image_space(env_observation_space)
            and observation_space == VecTransposeImage.transpose_space(env_observation_space)
        )
        if not matches_after_transpose:
            raise ValueError(f'Observation spaces do not match: {observation_space} != {env_observation_space}')

    env_action_space = env.action_space
    if action_space != env_action_space:
        raise ValueError(f'Action spaces do not match: {action_space} != {env_action_space}')
|
||||
|
||||
|
||||
def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
    """
    For every observation type, detects and validates the shape,
    then returns whether or not the observation is vectorized.

    Accepted shapes, per space type (anything else raises):

    - Box: ``observation_space.shape`` (single) or ``(n_env, *observation_space.shape)``
    - Discrete: ``()`` (single) or ``(n_env,)``
    - MultiDiscrete: ``(len(nvec),)`` (single) or ``(n_env, len(nvec))``
    - MultiBinary: ``(n,)`` (single) or ``(n_env, n)``

    :param observation: (np.ndarray) the input observation to validate
    :param observation_space: (gym.spaces) the observation space
    :return: (bool) whether the given observation is vectorized or not
    :raises ValueError: if the shape does not fit the space, or the space type is not supported
    """
    shape = observation.shape

    if isinstance(observation_space, gym.spaces.Box):
        if shape == observation_space.shape:
            return False
        # Everything after the leading (n_env) axis must match the space
        if shape[1:] == observation_space.shape:
            return True
        space_dims = ", ".join(map(str, observation_space.shape))
        raise ValueError(f"Error: Unexpected observation shape {shape} for "
                         f"Box environment, please use {observation_space.shape} "
                         f"or (n_env, {space_dims}) for the observation shape.")

    if isinstance(observation_space, gym.spaces.Discrete):
        if shape == ():  # A numpy array of a number, has shape empty tuple '()'
            return False
        if len(shape) == 1:
            return True
        # The message lists the shapes the two checks above actually accept
        raise ValueError(f"Error: Unexpected observation shape {shape} for "
                         "Discrete environment, please use () or (n_env,) for the observation shape.")

    if isinstance(observation_space, gym.spaces.MultiDiscrete):
        n_dims = len(observation_space.nvec)
        if shape == (n_dims,):
            return False
        if len(shape) == 2 and shape[1] == n_dims:
            return True
        raise ValueError(f"Error: Unexpected observation shape {shape} for MultiDiscrete "
                         f"environment, please use ({n_dims},) or "
                         f"(n_env, {n_dims}) for the observation shape.")

    if isinstance(observation_space, gym.spaces.MultiBinary):
        if shape == (observation_space.n,):
            return False
        if len(shape) == 2 and shape[1] == observation_space.n:
            return True
        raise ValueError(f"Error: Unexpected observation shape {shape} for MultiBinary "
                         f"environment, please use ({observation_space.n},) or "
                         f"(n_env, {observation_space.n}) for the observation shape.")

    raise ValueError("Error: Cannot determine if the observation is vectorized "
                     + f" with the space type {observation_space}.")
|
||||
|
||||
|
||||
def safe_mean(arr: Union[np.ndarray, list, deque]) -> float:
    """
    Compute the mean of an array if there is at least one element.
    For empty array, return NaN. It is used for logging only.

    :param arr: (np.ndarray, list or deque) values to average
    :return: (float) mean over all elements of ``arr``, or NaN when ``len(arr) == 0``
    """
    # Keep the explicit length test: ``not arr`` raises for numpy arrays with several elements
    return np.nan if len(arr) == 0 else np.mean(arr)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import pickle
|
|||
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper
|
||||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
|
||||
from stable_baselines3.common.running_mean_std import RunningMeanStd
|
||||
|
||||
|
||||
|
|
@ -160,12 +160,12 @@ class VecNormalize(VecEnvWrapper):
|
|||
return self.normalize_obs(obs)
|
||||
|
||||
@staticmethod
|
||||
def load(load_path, venv):
|
||||
def load(load_path: str, venv: VecEnv) -> "VecNormalize":
|
||||
"""
|
||||
Loads a saved VecNormalize object.
|
||||
|
||||
:param load_path: the path to load from.
|
||||
:param venv: the VecEnv to wrap.
|
||||
:param load_path: (str) the path to load from.
|
||||
:param venv: (VecEnv) the VecEnv to wrap.
|
||||
:return: (VecNormalize)
|
||||
"""
|
||||
with open(load_path, "rb") as file_handler:
|
||||
|
|
@ -173,22 +173,12 @@ class VecNormalize(VecEnvWrapper):
|
|||
vec_normalize.set_venv(venv)
|
||||
return vec_normalize
|
||||
|
||||
def save(self, save_path):
|
||||
def save(self, save_path: str) -> None:
|
||||
"""
|
||||
Save current VecNormalize object with
|
||||
all running statistics and settings (e.g. clip_obs)
|
||||
|
||||
:param save_path: (str) The path to save to
|
||||
"""
|
||||
with open(save_path, "wb") as file_handler:
|
||||
pickle.dump(self, file_handler)
|
||||
|
||||
def save_running_average(self, path):
    """
    Pickle ``self.obs_rms`` and ``self.ret_rms`` to ``obs_rms.pkl`` and ``ret_rms.pkl``
    inside ``path``.

    :param path: (str) path to log dir
    """
    import os  # local import: only needed to build the file paths

    for rms, name in zip([self.obs_rms, self.ret_rms], ['obs_rms', 'ret_rms']):
        # os.path.join instead of "{path}/{name}": an empty ``path`` then means the
        # current directory rather than the filesystem root
        with open(os.path.join(path, f"{name}.pkl"), 'wb') as file_handler:
            pickle.dump(rms, file_handler)
|
||||
|
||||
def load_running_average(self, path):
    """
    Set ``self.obs_rms`` and ``self.ret_rms`` from ``obs_rms.pkl`` and ``ret_rms.pkl``
    found inside ``path``.

    :param path: (str) path to log dir
    """
    import os  # local import: only needed to build the file paths

    for name in ['obs_rms', 'ret_rms']:
        # os.path.join instead of "{path}/{name}": an empty ``path`` then means the
        # current directory rather than the filesystem root
        with open(os.path.join(path, f"{name}.pkl"), 'rb') as file_handler:
            # NOTE(review): pickle.load can execute arbitrary code, only load trusted files
            setattr(self, name, pickle.load(file_handler))
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ if typing.TYPE_CHECKING:
|
|||
|
||||
class VecTransposeImage(VecEnvWrapper):
|
||||
"""
|
||||
Re-order channels, from WxHxC to CxWxH.
|
||||
Re-order channels, from HxWxC to CxHxW.
|
||||
It is required for PyTorch convolution layers.
|
||||
|
||||
:param venv: (VecEnv)
|
||||
|
|
|
|||
|
|
@ -1,375 +1,9 @@
|
|||
from typing import Optional, List, Tuple, Callable, Union, Dict, Type, Any
|
||||
from functools import partial
|
||||
# This file is here just to define MlpPolicy/CnnPolicy
|
||||
# that work for PPO
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
MlpPolicy = ActorCriticPolicy
|
||||
CnnPolicy = ActorCriticCnnPolicy
|
||||
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
|
||||
DiagGaussianDistribution, CategoricalDistribution,
|
||||
MultiCategoricalDistribution, BernoulliDistribution,
|
||||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class PPOPolicy(BasePolicy):
    """
    Policy class (with both actor and critic) for A2C and derivatives (PPO).

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (Callable) Learning rate schedule (could be constant)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
    :param device: (str or th.device) Device on which the code should run.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param ortho_init: (bool) Whether to use or not orthogonal initialization
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param full_std: (bool) Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: (bool) Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 lr_schedule: Callable,
                 net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
                 device: Union[th.device, str] = 'auto',
                 activation_fn: Type[nn.Module] = nn.Tanh,
                 ortho_init: bool = True,
                 use_sde: bool = False,
                 log_std_init: float = 0.0,
                 full_std: bool = True,
                 sde_net_arch: Optional[List[int]] = None,
                 use_expln: bool = False,
                 squash_output: bool = False,
                 features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                 normalize_images: bool = True,
                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                 optimizer_kwargs: Optional[Dict[str, Any]] = None):

        # The Adam ``eps`` default is only injected when the caller gave no optimizer kwargs,
        # so a user-provided dict is never modified.
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in ADAM optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs['eps'] = 1e-5

        super(PPOPolicy, self).__init__(observation_space, action_space,
                                        device,
                                        features_extractor_class,
                                        features_extractor_kwargs,
                                        optimizer_class=optimizer_class,
                                        optimizer_kwargs=optimizer_kwargs,
                                        squash_output=squash_output)

        # Default network architecture, from stable-baselines
        if net_arch is None:
            if features_extractor_class == FlattenExtractor:
                net_arch = [dict(pi=[64, 64], vf=[64, 64])]
            else:
                net_arch = []

        self.net_arch = net_arch
        self.activation_fn = activation_fn
        self.ortho_init = ortho_init

        self.features_extractor = features_extractor_class(self.observation_space,
                                                           **self.features_extractor_kwargs)
        self.features_dim = self.features_extractor.features_dim

        self.normalize_images = normalize_images
        self.log_std_init = log_std_init
        dist_kwargs = None
        # Keyword arguments for gSDE distribution
        if use_sde:
            dist_kwargs = {
                'full_std': full_std,
                'squash_output': squash_output,
                'use_expln': use_expln,
                'learn_features': sde_net_arch is not None
            }

        self.sde_features_extractor = None
        self.sde_net_arch = sde_net_arch
        self.use_sde = use_sde
        self.dist_kwargs = dist_kwargs

        # Action distribution
        self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)

        self._build(lr_schedule)

    def _get_data(self) -> Dict[str, Any]:
        """
        Extend the data returned by the parent class with the constructor
        arguments specific to this policy.

        :return: (Dict[str, Any]) parent data updated with this class' parameters
        """
        data = super()._get_data()

        # ``dist_kwargs`` is None unless gSDE is used; the gSDE-only entries are then None
        dist_kwargs = self.dist_kwargs or {}

        data.update(dict(
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            use_sde=self.use_sde,
            log_std_init=self.log_std_init,
            squash_output=dist_kwargs.get('squash_output'),
            full_std=dist_kwargs.get('full_std'),
            # ``sde_net_arch`` is not a key of ``dist_kwargs`` (only ``learn_features`` is derived
            # from it), so it is read from the attribute set in ``__init__``.
            sde_net_arch=self.sde_net_arch,
            use_expln=dist_kwargs.get('use_expln'),
            lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
            ortho_init=self.ortho_init,
            optimizer_class=self.optimizer_class,
            optimizer_kwargs=self.optimizer_kwargs,
            features_extractor_class=self.features_extractor_class,
            features_extractor_kwargs=self.features_extractor_kwargs
        ))
        return data

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.

        :param n_envs: (int) passed as ``batch_size`` to ``sample_weights``
        """
        assert isinstance(self.action_dist,
                          StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
        self.action_dist.sample_weights(self.log_std, batch_size=n_envs)

    def _build(self, lr_schedule: Callable) -> None:
        """
        Create the networks and the optimizer.

        :param lr_schedule: (Callable) Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
                                          activation_fn=self.activation_fn, device=self.device)

        latent_dim_pi = self.mlp_extractor.latent_dim_pi

        # Separate feature extractor for gSDE
        if self.sde_net_arch is not None:
            self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim,
                                                                                        self.sde_net_arch,
                                                                                        self.activation_fn)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi,
                                                                                    log_std_init=self.log_std_init)
        elif isinstance(self.action_dist, StateDependentNoiseDistribution):
            latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
            self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi,
                                                                                    latent_sde_dim=latent_sde_dim,
                                                                                    log_std_init=self.log_std_init)
        elif isinstance(self.action_dist, (CategoricalDistribution,
                                           MultiCategoricalDistribution,
                                           BernoulliDistribution)):
            # These distributions only need the action network (no log std parameter)
            self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)

        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
        # Init weights: use orthogonal initialization
        # with small initial weight for the output
        if self.ortho_init:
            # TODO: check for features_extractor
            # Values from stable-baselines, TODO: check why
            module_gains = [
                (self.features_extractor, np.sqrt(2)),
                (self.mlp_extractor, np.sqrt(2)),
                (self.action_net, 0.01),
                (self.value_net, 1),
            ]
            for module, gain in module_gains:
                module.apply(partial(self.init_weights, gain=gain))
        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def forward(self, obs: th.Tensor,
                deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: (th.Tensor) Observation
        :param deterministic: (bool) Whether to sample or use deterministic actions
        :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) action, value and log probability of the action
        """
        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Get the latent code (i.e., activations of the last layer of each network)
        for the different networks.

        :param obs: (th.Tensor) Observation
        :return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes
            for the actor, the value function and for gSDE function
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)

        # Features for sde: the policy latent code, unless a separate extractor was built
        latent_sde = latent_pi
        if self.sde_features_extractor is not None:
            latent_sde = self.sde_features_extractor(features)
        return latent_pi, latent_vf, latent_sde

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
                                     latent_sde: Optional[th.Tensor] = None) -> Distribution:
        """
        Retrieve action distribution given the latent codes.

        :param latent_pi: (th.Tensor) Latent code for the actor
        :param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function
        :return: (Distribution) Action distribution
        :raises ValueError: if ``self.action_dist`` is not one of the handled distribution types
        """
        mean_actions = self.action_net(latent_pi)

        if isinstance(self.action_dist, DiagGaussianDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std)
        if isinstance(self.action_dist, (CategoricalDistribution,
                                         MultiCategoricalDistribution,
                                         BernoulliDistribution)):
            # Here mean_actions are the logits: before the softmax (Categorical),
            # flattened (MultiCategorical) or before rounding to binary actions (Bernoulli)
            return self.action_dist.proba_distribution(action_logits=mean_actions)
        if isinstance(self.action_dist, StateDependentNoiseDistribution):
            return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
        raise ValueError('Invalid action distribution')

    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        """
        Get the action according to the policy for a given observation.

        :param observation: (th.Tensor)
        :param deterministic: (bool) Whether to use stochastic or deterministic actions
        :return: (th.Tensor) Taken action according to the policy
        """
        latent_pi, _, latent_sde = self._get_latent(observation)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
        return distribution.get_actions(deterministic=deterministic)

    def evaluate_actions(self, obs: th.Tensor,
                         actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Evaluate actions according to the current policy,
        given the observations.

        :param obs: (th.Tensor)
        :param actions: (th.Tensor)
        :return: (th.Tensor, th.Tensor, th.Tensor) estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        latent_pi, latent_vf, latent_sde = self._get_latent(obs)
        distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        return values, log_prob, distribution.entropy()
|
||||
|
||||
|
||||
MlpPolicy = PPOPolicy
|
||||
|
||||
|
||||
class CnnPolicy(PPOPolicy):
    """
    Actor-critic policy for A2C and derivatives (PPO) whose default features
    extractor is ``NatureCNN``. Every argument is handed over unchanged to ``PPOPolicy``.

    :param observation_space: (gym.spaces.Space) Observation space
    :param action_space: (gym.spaces.Space) Action space
    :param lr_schedule: (Callable) Learning rate schedule (could be constant)
    :param net_arch: ([int or dict]) The specification of the policy and value networks.
    :param device: (str or th.device) Device on which the code should run.
    :param activation_fn: (Type[nn.Module]) Activation function
    :param ortho_init: (bool) Whether to use or not orthogonal initialization
    :param use_sde: (bool) Whether to use State Dependent Exploration or not
    :param log_std_init: (float) Initial value for the log standard deviation
    :param full_std: (bool) Whether to use (n_features x n_actions) parameters
        for the std instead of only (n_features,) when using gSDE
    :param sde_net_arch: ([int]) Network architecture for extracting features
        when using gSDE. If None, the latent features from the policy will be used.
        Pass an empty list to use the states as features.
    :param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
        a positive standard deviation (cf paper). It allows to keep variance
        above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
    :param squash_output: (bool) Whether to squash the output using a tanh function,
        this allows to ensure boundaries when using gSDE.
    :param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
    :param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
        to pass to the feature extractor.
    :param normalize_images: (bool) Whether to normalize images or not,
         dividing by 255.0 (True by default)
    :param optimizer_class: (Type[th.optim.Optimizer]) The optimizer to use,
        ``th.optim.Adam`` by default
    :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
        excluding the learning rate, to pass to the optimizer
    """

    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 lr_schedule: Callable,
                 net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
                 device: Union[th.device, str] = 'auto',
                 activation_fn: Type[nn.Module] = nn.Tanh,
                 ortho_init: bool = True,
                 use_sde: bool = False,
                 log_std_init: float = 0.0,
                 full_std: bool = True,
                 sde_net_arch: Optional[List[int]] = None,
                 use_expln: bool = False,
                 squash_output: bool = False,
                 features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
                 features_extractor_kwargs: Optional[Dict[str, Any]] = None,
                 normalize_images: bool = True,
                 optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
                 optimizer_kwargs: Optional[Dict[str, Any]] = None):
        # Only the default of ``features_extractor_class`` differs from the parent class;
        # the arguments are forwarded by name, in the parent's declaration order.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            device=device,
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            sde_net_arch=sde_net_arch,
            use_expln=use_expln,
            squash_output=squash_output,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )
|
||||
|
||||
|
||||
register_policy("MlpPolicy", MlpPolicy)
|
||||
register_policy("CnnPolicy", CnnPolicy)
|
||||
register_policy("MlpPolicy", ActorCriticPolicy)
|
||||
register_policy("CnnPolicy", ActorCriticCnnPolicy)
|
||||
|
|
|
|||
|
|
@ -1,23 +1,18 @@
|
|||
import time
|
||||
from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
|
||||
from typing import Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.base_class import BaseRLModel
|
||||
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.buffers import RolloutBuffer
|
||||
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.ppo.policies import PPOPolicy
|
||||
from stable_baselines3.common.policies import ActorCriticPolicy
|
||||
|
||||
|
||||
class PPO(BaseRLModel):
|
||||
class PPO(OnPolicyAlgorithm):
|
||||
"""
|
||||
Proximal Policy Optimization algorithm (PPO) (clip version)
|
||||
|
||||
|
|
@ -28,10 +23,10 @@ class PPO(BaseRLModel):
|
|||
|
||||
Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
|
||||
|
||||
:param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
|
||||
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: (float or callable) The learning rate, it can be a function
|
||||
of the current progress (from 1 to 0)
|
||||
of the current progress remaining (from 1 to 0)
|
||||
:param n_steps: (int) The number of steps to run for each environment per update
|
||||
(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel)
|
||||
:param batch_size: (int) Minibatch size
|
||||
|
|
@ -39,9 +34,9 @@ class PPO(BaseRLModel):
|
|||
:param gamma: (float) Discount factor
|
||||
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
|
||||
:param clip_range: (float or callable) Clipping parameter, it can be a function of the current progress
|
||||
(from 1 to 0).
|
||||
remaining (from 1 to 0).
|
||||
:param clip_range_vf: (float or callable) Clipping parameter for the value function,
|
||||
it can be a function of the current progress (from 1 to 0).
|
||||
it can be a function of the current progress remaining (from 1 to 0).
|
||||
This is a parameter specific to the OpenAI implementation. If None is passed (default),
|
||||
no clipping will be done on the value function.
|
||||
IMPORTANT: this clipping depends on the reward scaling.
|
||||
|
|
@ -67,7 +62,7 @@ class PPO(BaseRLModel):
|
|||
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
|
||||
"""
|
||||
|
||||
def __init__(self, policy: Union[str, Type[PPOPolicy]],
|
||||
def __init__(self, policy: Union[str, Type[ActorCriticPolicy]],
|
||||
env: Union[GymEnv, str],
|
||||
learning_rate: Union[float, Callable] = 3e-4,
|
||||
n_steps: int = 2048,
|
||||
|
|
@ -91,41 +86,27 @@ class PPO(BaseRLModel):
|
|||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True):
|
||||
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate,
|
||||
policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log,
|
||||
verbose=verbose, device=device, use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
create_eval_env=create_eval_env, support_multi_env=True, seed=seed)
|
||||
super(PPO, self).__init__(policy, env, learning_rate=learning_rate,
|
||||
n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda,
|
||||
ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
|
||||
verbose=verbose, device=device, create_eval_env=create_eval_env,
|
||||
seed=seed, _init_setup_model=False)
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.n_epochs = n_epochs
|
||||
self.n_steps = n_steps
|
||||
self.gamma = gamma
|
||||
self.gae_lambda = gae_lambda
|
||||
self.clip_range = clip_range
|
||||
self.clip_range_vf = clip_range_vf
|
||||
self.ent_coef = ent_coef
|
||||
self.vf_coef = vf_coef
|
||||
self.max_grad_norm = max_grad_norm
|
||||
self.rollout_buffer = None
|
||||
self.target_kl = target_kl
|
||||
self.tb_writer = None
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_lr_schedule()
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space,
|
||||
self.action_space, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.lr_schedule, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
super(PPO, self)._setup_model()
|
||||
|
||||
# Initialize schedules for policy/value clipping
|
||||
self.clip_range = get_schedule_fn(self.clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
if isinstance(self.clip_range_vf, (float, int)):
|
||||
|
|
@ -134,77 +115,28 @@ class PPO(BaseRLModel):
|
|||
|
||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||
|
||||
def collect_rollouts(self,
|
||||
env: VecEnv,
|
||||
callback: BaseCallback,
|
||||
rollout_buffer: RolloutBuffer,
|
||||
n_rollout_steps: int = 256) -> bool:
|
||||
|
||||
assert self._last_obs is not None, "No previous observation was provided"
|
||||
n_steps = 0
|
||||
rollout_buffer.reset()
|
||||
# Sample new weights for the state dependent exploration
|
||||
if self.use_sde:
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
||||
callback.on_rollout_start()
|
||||
|
||||
while n_steps < n_rollout_steps:
|
||||
if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.policy.reset_noise(env.num_envs)
|
||||
|
||||
with th.no_grad():
|
||||
# Convert to pytorch tensor
|
||||
obs_tensor = th.as_tensor(self._last_obs).to(self.device)
|
||||
actions, values, log_probs = self.policy.forward(obs_tensor)
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale and perform action
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
new_obs, rewards, dones, infos = env.step(clipped_actions)
|
||||
|
||||
if callback.on_step() is False:
|
||||
return False
|
||||
|
||||
self._update_info_buffer(infos)
|
||||
n_steps += 1
|
||||
self.num_timesteps += env.num_envs
|
||||
|
||||
if isinstance(self.action_space, gym.spaces.Discrete):
|
||||
# Reshape in case of discrete action
|
||||
actions = actions.reshape(-1, 1)
|
||||
rollout_buffer.add(self._last_obs, actions, rewards, dones, values, log_probs)
|
||||
self._last_obs = new_obs
|
||||
|
||||
rollout_buffer.compute_returns_and_advantage(values, dones=dones)
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return True
|
||||
|
||||
def train(self, n_epochs: int, batch_size: int = 64) -> None:
|
||||
def train(self) -> None:
|
||||
"""
|
||||
Update policy using the currently gathered
|
||||
rollout buffer.
|
||||
"""
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
clip_range = self.clip_range(self._current_progress)
|
||||
clip_range = self.clip_range(self._current_progress_remaining)
|
||||
# Optional: clip range for the value function
|
||||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress_remaining)
|
||||
|
||||
entropy_losses, all_kl_divs = [], []
|
||||
pg_losses, value_losses = [], []
|
||||
clip_fractions = []
|
||||
|
||||
# train for gradient_steps epochs
|
||||
for epoch in range(n_epochs):
|
||||
for epoch in range(self.n_epochs):
|
||||
approx_kl_divs = []
|
||||
# Do a complete pass on the rollout buffer
|
||||
for rollout_data in self.rollout_buffer.get(batch_size):
|
||||
for rollout_data in self.rollout_buffer.get(self.batch_size):
|
||||
actions = rollout_data.actions
|
||||
if isinstance(self.action_space, spaces.Discrete):
|
||||
# Convert discrete action from float to long
|
||||
|
|
@ -214,7 +146,7 @@ class PPO(BaseRLModel):
|
|||
# TODO: investigate why there is no issue with the gradient
|
||||
# if that line is commented (as in SAC)
|
||||
if self.use_sde:
|
||||
self.policy.reset_noise(batch_size)
|
||||
self.policy.reset_noise(self.batch_size)
|
||||
|
||||
values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
|
||||
values = values.flatten()
|
||||
|
|
@ -272,7 +204,7 @@ class PPO(BaseRLModel):
|
|||
print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}")
|
||||
break
|
||||
|
||||
self._n_updates += n_epochs
|
||||
self._n_updates += self.n_epochs
|
||||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
|
|
@ -303,48 +235,7 @@ class PPO(BaseRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> "PPO":
|
||||
|
||||
iteration = 0
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
tb_log_name)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
||||
continue_training = self.collect_rollouts(self.env, callback,
|
||||
self.rollout_buffer,
|
||||
n_rollout_steps=self.n_steps)
|
||||
|
||||
if continue_training is False:
|
||||
break
|
||||
|
||||
iteration += 1
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
# Log training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.record("rollout/ep_rew_mean",
|
||||
self.safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("rollout/ep_len_mean",
|
||||
self.safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("time/fps", fps)
|
||||
logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
|
||||
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
logger.dump(step=self.num_timesteps)
|
||||
|
||||
self.train(self.n_epochs, batch_size=self.batch_size)
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
||||
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
cf base class
|
||||
"""
|
||||
state_dicts = ["policy", "policy.optimizer"]
|
||||
|
||||
return state_dicts, []
|
||||
return super(PPO, self).learn(total_timesteps=total_timesteps, callback=callback,
|
||||
log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq,
|
||||
n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name,
|
||||
eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor
|
||||
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ import torch.nn.functional as F
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.base_class import OffPolicyRLModel
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.sac.policies import SACPolicy
|
||||
|
||||
|
||||
class SAC(OffPolicyRLModel):
|
||||
class SAC(OffPolicyAlgorithm):
|
||||
"""
|
||||
Soft Actor-Critic (SAC)
|
||||
Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor,
|
||||
|
|
@ -28,7 +28,7 @@ class SAC(OffPolicyRLModel):
|
|||
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: (float or callable) learning rate for adam optimizer,
|
||||
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
|
|
@ -252,7 +252,7 @@ class SAC(OffPolicyRLModel):
|
|||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "SAC",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
|
||||
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
|
|
@ -270,7 +270,7 @@ class SAC(OffPolicyRLModel):
|
|||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor
|
||||
from stable_baselines3.common.torch_layers import (create_mlp, NatureCNN, BaseFeaturesExtractor,
|
||||
FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import StateDependentNoiseDistribution
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,22 @@
|
|||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common.base_class import OffPolicyRLModel
|
||||
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn
|
||||
from stable_baselines3.common.utils import safe_mean
|
||||
from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.noise import ActionNoise
|
||||
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from stable_baselines3.td3.policies import TD3Policy
|
||||
|
||||
|
||||
class TD3(OffPolicyRLModel):
|
||||
class TD3(OffPolicyAlgorithm):
|
||||
"""
|
||||
Twin Delayed DDPG (TD3)
|
||||
Addressing Function Approximation Error in Actor-Critic Methods.
|
||||
|
|
@ -22,7 +29,7 @@ class TD3(OffPolicyRLModel):
|
|||
:param env: (GymEnv or str) The environment to learn from (if registered in Gym, can be str)
|
||||
:param learning_rate: (float or callable) learning rate for adam optimizer,
|
||||
the same learning rate will be used for all networks (Q-Values, Actor and Value function)
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
it can be a function of the current progress remaining (from 1 to 0)
|
||||
:param buffer_size: (int) size of the replay buffer
|
||||
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
|
||||
:param batch_size: (int) Minibatch size for each gradient update
|
||||
|
|
@ -197,7 +204,7 @@ class TD3(OffPolicyRLModel):
|
|||
value_loss = F.mse_loss(returns, values)
|
||||
|
||||
# A2C loss
|
||||
policy_loss = -(advantage * log_prob).mean()
|
||||
policy_loss = -(advantage * log_prob).mean() # pytype: disable=attribute-error
|
||||
|
||||
# Entropy loss favor exploration
|
||||
if entropy is None:
|
||||
|
|
@ -233,7 +240,7 @@ class TD3(OffPolicyRLModel):
|
|||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "TD3",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
|
||||
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
|
|
@ -252,14 +259,14 @@ class TD3(OffPolicyRLModel):
|
|||
if rollout.continue_training is False:
|
||||
break
|
||||
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
|
||||
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
|
||||
if self.use_sde:
|
||||
if self.sde_log_std_scheduler is not None:
|
||||
# Call the scheduler
|
||||
value = self.sde_log_std_scheduler(self._current_progress)
|
||||
value = self.sde_log_std_scheduler(self._current_progress_remaining)
|
||||
self.actor.log_std.data = th.ones_like(self.actor.log_std) * value
|
||||
else:
|
||||
# On-policy gradient
|
||||
|
|
@ -272,6 +279,182 @@ class TD3(OffPolicyRLModel):
|
|||
|
||||
return self
|
||||
|
||||
def collect_rollouts(self, # noqa: C901
|
||||
env: VecEnv,
|
||||
# Type hint as string to avoid circular import
|
||||
callback: 'BaseCallback',
|
||||
n_episodes: int = 1,
|
||||
n_steps: int = -1,
|
||||
action_noise: Optional[ActionNoise] = None,
|
||||
learning_starts: int = 0,
|
||||
replay_buffer: Optional[ReplayBuffer] = None,
|
||||
log_interval: Optional[int] = None) -> RolloutReturn:
|
||||
"""
|
||||
Collect rollout using the current policy (and possibly fill the replay buffer)
|
||||
|
||||
:param env: (VecEnv) The training environment
|
||||
:param n_episodes: (int) Number of episodes to use to collect rollout data
|
||||
You can also specify a ``n_steps`` instead
|
||||
:param n_steps: (int) Number of steps to use to collect rollout data
|
||||
You can also specify a ``n_episodes`` instead.
|
||||
:param action_noise: (Optional[ActionNoise]) Action noise that will be used for exploration
|
||||
Required for deterministic policy (e.g. TD3). This can also be used
|
||||
in addition to the stochastic policy for SAC.
|
||||
:param callback: (BaseCallback) Callback that will be called at each step
|
||||
(and at the beginning and end of the rollout)
|
||||
:param learning_starts: (int) Number of steps before learning for the warm-up phase.
|
||||
:param replay_buffer: (ReplayBuffer)
|
||||
:param log_interval: (int) Log data every ``log_interval`` episodes
|
||||
:return: (RolloutReturn)
|
||||
"""
|
||||
episode_rewards, total_timesteps = [], []
|
||||
total_steps, total_episodes = 0, 0
|
||||
|
||||
assert isinstance(env, VecEnv), "You must pass a VecEnv"
|
||||
assert env.num_envs == 1, "OffPolicyRLModel only support single environment"
|
||||
|
||||
self.rollout_data = None
|
||||
if self.use_sde:
|
||||
self.actor.reset_noise()
|
||||
# Reset rollout data
|
||||
if self.on_policy_exploration:
|
||||
self.rollout_data = {key: [] for key in ['observations', 'actions', 'rewards', 'dones', 'values']}
|
||||
|
||||
callback.on_rollout_start()
|
||||
continue_training = True
|
||||
|
||||
while total_steps < n_steps or total_episodes < n_episodes:
|
||||
done = False
|
||||
episode_reward, episode_timesteps = 0.0, 0
|
||||
|
||||
while not done:
|
||||
|
||||
if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
|
||||
# Sample a new noise matrix
|
||||
self.actor.reset_noise()
|
||||
|
||||
# Select action randomly or according to policy
|
||||
if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
|
||||
# Warmup phase
|
||||
unscaled_action = np.array([self.action_space.sample()])
|
||||
else:
|
||||
# Note: we assume that the policy uses tanh to scale the action
|
||||
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
|
||||
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
|
||||
|
||||
# Rescale the action from [low, high] to [-1, 1]
|
||||
scaled_action = self.policy.scale_action(unscaled_action)
|
||||
|
||||
if self.use_sde:
|
||||
# When using SDE, the action can be out of bounds
|
||||
# TODO: fix with squashing and account for that in the proba distribution
|
||||
clipped_action = np.clip(scaled_action, -1, 1)
|
||||
else:
|
||||
clipped_action = scaled_action
|
||||
|
||||
# Add noise to the action (improve exploration)
|
||||
if action_noise is not None:
|
||||
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
|
||||
# Update(October 2019): Not anymore
|
||||
clipped_action = np.clip(clipped_action + action_noise(), -1, 1)
|
||||
|
||||
# Rescale and perform action
|
||||
new_obs, reward, done, infos = env.step(self.policy.unscale_action(clipped_action))
|
||||
|
||||
# Only stop training if return value is False, not when it is None.
|
||||
if callback.on_step() is False:
|
||||
return RolloutReturn(0.0, total_steps, total_episodes, continue_training=False)
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
# Retrieve reward and episode length if using Monitor wrapper
|
||||
self._update_info_buffer(infos, done)
|
||||
|
||||
# Store data in replay buffer
|
||||
if replay_buffer is not None:
|
||||
# Store only the unnormalized version
|
||||
if self._vec_normalize_env is not None:
|
||||
new_obs_ = self._vec_normalize_env.get_original_obs()
|
||||
reward_ = self._vec_normalize_env.get_original_reward()
|
||||
else:
|
||||
# Avoid changing the original ones
|
||||
self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
|
||||
|
||||
replay_buffer.add(self._last_original_obs, new_obs_, clipped_action, reward_, done)
|
||||
|
||||
if self.rollout_data is not None:
|
||||
# Assume only one env
|
||||
self.rollout_data['observations'].append(self._last_obs[0].copy())
|
||||
self.rollout_data['actions'].append(scaled_action[0].copy())
|
||||
self.rollout_data['rewards'].append(reward[0].copy())
|
||||
self.rollout_data['dones'].append(done[0].copy())
|
||||
obs_tensor = th.FloatTensor(self._last_obs).to(self.device)
|
||||
self.rollout_data['values'].append(self.vf_net(obs_tensor)[0].cpu().detach().numpy())
|
||||
|
||||
self._last_obs = new_obs
|
||||
# Save the unnormalized observation
|
||||
if self._vec_normalize_env is not None:
|
||||
self._last_original_obs = new_obs_
|
||||
|
||||
self.num_timesteps += 1
|
||||
episode_timesteps += 1
|
||||
total_steps += 1
|
||||
if 0 < n_steps <= total_steps:
|
||||
break
|
||||
|
||||
if done:
|
||||
total_episodes += 1
|
||||
self._episode_num += 1
|
||||
episode_rewards.append(episode_reward)
|
||||
total_timesteps.append(episode_timesteps)
|
||||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Log training infos
|
||||
if log_interval is not None and self._episode_num % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.record('rollout/ep_rew_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record('rollout/ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("time/fps", fps)
|
||||
logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
|
||||
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
if self.use_sde:
|
||||
logger.record("train/std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
logger.record('rollout/success rate', safe_mean(self.ep_success_buffer))
|
||||
# Pass the number of timesteps for tensorboard
|
||||
logger.dump(step=self.num_timesteps)
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
# Post processing
|
||||
if self.rollout_data is not None:
|
||||
for key in ['observations', 'actions', 'rewards', 'dones', 'values']:
|
||||
self.rollout_data[key] = th.FloatTensor(np.array(self.rollout_data[key])).to(self.device)
|
||||
|
||||
self.rollout_data['returns'] = self.rollout_data['rewards'].clone() # pytype: disable=attribute-error
|
||||
self.rollout_data['advantage'] = self.rollout_data['rewards'].clone() # pytype: disable=attribute-error
|
||||
|
||||
# Compute return and advantage
|
||||
last_return = 0.0
|
||||
for step in reversed(range(len(self.rollout_data['rewards']))):
|
||||
if step == len(self.rollout_data['rewards']) - 1:
|
||||
next_non_terminal = 1.0 - done[0]
|
||||
next_value = self.vf_net(th.FloatTensor(self._last_obs).to(self.device))[0].detach()
|
||||
last_return = self.rollout_data['rewards'][step] + next_non_terminal * next_value
|
||||
else:
|
||||
next_non_terminal = 1.0 - self.rollout_data['dones'][step + 1]
|
||||
last_return = self.rollout_data['rewards'][step] + self.gamma * last_return * next_non_terminal
|
||||
self.rollout_data['returns'][step] = last_return
|
||||
self.rollout_data['advantage'] = self.rollout_data['returns'] - self.rollout_data['values']
|
||||
|
||||
callback.on_rollout_end()
|
||||
|
||||
return RolloutReturn(mean_reward, total_steps, total_episodes, continue_training)
|
||||
|
||||
def excluded_save_params(self) -> List[str]:
|
||||
"""
|
||||
Returns the names of the parameters that should be excluded by default
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.7.0a1
|
||||
0.8.0a0
|
||||
|
|
|
|||
|
|
@ -14,14 +14,27 @@ def test_td3(action_noise):
|
|||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO])
|
||||
@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
|
||||
def test_onpolicy(model_class, env_id):
|
||||
model = model_class('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
def test_a2c(env_id):
|
||||
model = A2C('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ent_coef", ['auto', 0.01])
|
||||
@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
|
||||
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
|
||||
def test_ppo(env_id, clip_range_vf):
|
||||
if clip_range_vf is not None and clip_range_vf < 0:
|
||||
# Should throw an error
|
||||
with pytest.raises(AssertionError):
|
||||
model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
|
||||
clip_range_vf=clip_range_vf)
|
||||
else:
|
||||
model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
|
||||
clip_range_vf=clip_range_vf)
|
||||
model.learn(total_timesteps=1000, eval_freq=500)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ent_coef", ['auto', 0.01, 'auto_0.01'])
|
||||
def test_sac(ent_coef):
|
||||
model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
|
||||
learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef,
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def test_save_load(model_class):
|
|||
|
||||
''warning does not test function of optimizer parameter load
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ def test_save_load(model_class):
|
|||
def test_set_env(model_class):
|
||||
"""
|
||||
Test if set_env function does work correct
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
env2 = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
|
@ -111,7 +111,7 @@ def test_exclude_include_saved_params(model_class):
|
|||
"""
|
||||
Test if exclude and include parameters of save() work
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
"""
|
||||
env = DummyVecEnv([lambda: IdentityEnvBox(10)])
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ def test_save_load_policy(model_class, policy_str):
|
|||
"""
|
||||
Test saving and loading policy only.
|
||||
|
||||
:param model_class: (BaseRLModel) A RL model
|
||||
:param model_class: (BaseAlgorithm) A RL model
|
||||
:param policy_str: (str) Name of the policy.
|
||||
"""
|
||||
kwargs = {}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
|
||||
|
||||
|
|
@ -40,32 +41,18 @@ def test_check_nan():
|
|||
|
||||
env.step([[0]])
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError):
|
||||
env.step([[float('NaN')]])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError):
|
||||
env.step([[float('inf')]])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError):
|
||||
env.step([[-1]])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError):
|
||||
env.step([[1]])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False
|
||||
|
||||
env.step(np.array([[0, 1], [0, 1]]))
|
||||
|
||||
env.reset()
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ def test_sync_vec_normalize():
|
|||
|
||||
assert unwrap_vec_normalize(env) is None
|
||||
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
|
||||
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100., clip_reward=100.)
|
||||
|
||||
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
|
||||
|
||||
|
|
@ -145,22 +145,33 @@ def test_sync_vec_normalize():
|
|||
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
|
||||
|
||||
eval_env = DummyVecEnv([make_env])
|
||||
eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
|
||||
eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True,
|
||||
clip_obs=100., clip_reward=100.)
|
||||
eval_env = VecFrameStack(eval_env, 1)
|
||||
|
||||
env.seed(0)
|
||||
env.action_space.seed(0)
|
||||
|
||||
env.reset()
|
||||
# Initialize running mean
|
||||
latest_reward = None
|
||||
for _ in range(100):
|
||||
env.step([env.action_space.sample()])
|
||||
_, latest_reward, _, _ = env.step([env.action_space.sample()])
|
||||
|
||||
# Check that unnormalized reward is same as original reward
|
||||
original_latest_reward = env.get_original_reward()
|
||||
assert np.allclose(original_latest_reward, env.unnormalize_reward(latest_reward))
|
||||
|
||||
obs = env.reset()
|
||||
original_obs = env.get_original_obs()
|
||||
dummy_rewards = np.random.rand(10)
|
||||
# Normalization must be different
|
||||
original_obs = env.get_original_obs()
|
||||
# Check that unnormalization works
|
||||
assert np.allclose(original_obs, env.unnormalize_obs(obs))
|
||||
# Normalization must be different (between different environments)
|
||||
assert not np.allclose(obs, eval_env.normalize_obs(original_obs))
|
||||
|
||||
# Test syncing of parameters
|
||||
sync_envs_normalization(env, eval_env)
|
||||
|
||||
# Now they must be synced
|
||||
assert np.allclose(obs, eval_env.normalize_obs(original_obs))
|
||||
assert np.allclose(env.normalize_reward(dummy_rewards), eval_env.normalize_reward(dummy_rewards))
|
||||
|
|
|
|||
Loading…
Reference in a new issue