mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-26 03:01:19 +00:00
Merge branch 'master' into sde
This commit is contained in:
commit
45fd3028c1
24 changed files with 616 additions and 242 deletions
|
|
@ -1,4 +1,4 @@
|
|||
image: stablebaselines/stable-baselines3-cpu:0.6.0a7
|
||||
image: stablebaselines/stable-baselines3-cpu:0.6.0
|
||||
|
||||
type-check:
|
||||
script:
|
||||
|
|
@ -15,5 +15,4 @@ doc-build:
|
|||
|
||||
lint-check:
|
||||
script:
|
||||
- pip install flake8 # TODO: remove when new version on Pypi
|
||||
- make lint
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ RUN \
|
|||
cd ${CODE_DIR}/stable-baselines3 3&& \
|
||||
pip install -e .[extra,tests,docs] && \
|
||||
# Use headless version for docker
|
||||
pip uninstall -y opencv-python && \
|
||||
pip install opencv-python-headless && \
|
||||
rm -rf $HOME/.cache/pip
|
||||
|
||||
|
|
|
|||
|
|
@ -34,12 +34,12 @@ These algorithms will make it easier for the research community and industry to
|
|||
| Custom policies | :heavy_check_mark: |
|
||||
| Common interface | :heavy_check_mark: |
|
||||
| Ipython / Notebook friendly | :heavy_check_mark: |
|
||||
| Tensorboard support | :heavy_check_mark: |
|
||||
| PEP8 code style | :heavy_check_mark: |
|
||||
| Custom callback | :heavy_check_mark: |
|
||||
| High code coverage | :heavy_check_mark: |
|
||||
| Type hints | :heavy_check_mark: |
|
||||
|
||||
<!-- | Tensorboard support | :heavy_check_mark: | -->
|
||||
|
||||
### Roadmap to V1.0
|
||||
|
||||
|
|
@ -49,11 +49,9 @@ Planned features:
|
|||
- [ ] DQN (almost ready, currently in testing phase)
|
||||
- [ ] DDPG (you can use its successor TD3 for now)
|
||||
- [ ] HER
|
||||
- [ ] Support for MultiDiscrete and MultiBinary action spaces
|
||||
|
||||
### Planned features (v1.1+)
|
||||
|
||||
- [ ] Full Tensorboard support
|
||||
- [ ] DQN extensions (prioritized replay, double q-learning, ...)
|
||||
- [ ] Support for `Tuple` and `Dict` observation spaces
|
||||
- [ ] Recurrent Policies
|
||||
|
|
@ -104,7 +102,7 @@ Install the Stable Baselines3 package:
|
|||
pip install stable-baselines3[extra]
|
||||
```
|
||||
|
||||
This includes an optional dependencies like OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
|
||||
This includes optional dependencies like Tensorboard, OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
|
||||
```
|
||||
pip install stable-baselines3
|
||||
```
|
||||
|
|
|
|||
BIN
docs/_static/img/Tensorboard_example.png
vendored
Normal file
BIN
docs/_static/img/Tensorboard_example.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 233 KiB |
|
|
@ -44,7 +44,7 @@ This will give you access to events (``_on_training_start``, ``_on_step``) and u
|
|||
# self.locals = None # type: Dict[str, Any]
|
||||
# self.globals = None # type: Dict[str, Any]
|
||||
# The logger object, used to report things in the terminal
|
||||
# self.logger = None # type: logger.Logger
|
||||
# self.logger = None # stable_baselines3.common.logger
|
||||
# # Sometimes, for event callback, it is useful
|
||||
# # to have access to the parent object
|
||||
# self.parent = None # type: Optional[BaseCallback]
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ To install Stable Baselines3 with pip, execute:
|
|||
|
||||
pip install stable-baselines3[extra]
|
||||
|
||||
This includes an optional dependencies like OpenCV or ```atari-py``` to train on atari games. If you do not need those, you can use:
|
||||
This includes optional dependencies like Tensorboard, OpenCV or ``atari-py`` to train on atari games. If you do not need those, you can use:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
|
|
|||
82
docs/guide/tensorboard.rst
Normal file
82
docs/guide/tensorboard.rst
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
.. _tensorboard:
|
||||
|
||||
Tensorboard Integration
|
||||
=======================
|
||||
|
||||
Basic Usage
|
||||
------------
|
||||
|
||||
To use Tensorboard with Stable Baselines3, you simply need to pass the location of the log folder to the RL agent:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10000)
|
||||
|
||||
|
||||
You can also define a custom logging name when training (by default, it is the algorithm name):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from stable_baselines3 import A2C
|
||||
|
||||
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
|
||||
model.learn(total_timesteps=10000, tb_log_name="first_run")
|
||||
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
|
||||
# By default, it will create a new curve
|
||||
model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)
|
||||
model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False)
|
||||
|
||||
|
||||
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
tensorboard --logdir ./a2c_cartpole_tensorboard/
|
||||
|
||||
You can also add past logging folders:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/
|
||||
|
||||
It will display information such as the episode reward (when using a ``Monitor`` wrapper), the model losses and other parameters unique to some models.
|
||||
|
||||
.. image:: ../_static/img/Tensorboard_example.png
|
||||
:width: 600
|
||||
:alt: plotting
|
||||
|
||||
Logging More Values
|
||||
-------------------
|
||||
|
||||
Using a callback, you can easily log more values with TensorBoard.
|
||||
Here is a simple example of how to log an additional tensor or an arbitrary scalar value:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import SAC
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
|
||||
model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)
|
||||
|
||||
|
||||
class TensorboardCallback(BaseCallback):
|
||||
"""
|
||||
Custom callback for plotting additional values in tensorboard.
|
||||
"""
|
||||
|
||||
def __init__(self, verbose=0):
|
||||
super(TensorboardCallback, self).__init__(verbose)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
# Log scalar value (here a random variable)
|
||||
value = np.random.random()
|
||||
self.logger.record('random_value', value)
|
||||
return True
|
||||
|
||||
|
||||
model.learn(50000, callback=TensorboardCallback())
|
||||
|
|
@ -25,7 +25,7 @@ Main Features
|
|||
- Documented functions and classes
|
||||
- Tests, high code coverage and type hints
|
||||
- Clean code
|
||||
|
||||
- Tensorboard support
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
|
@ -42,6 +42,7 @@ Main Features
|
|||
guide/custom_env
|
||||
guide/custom_policy
|
||||
guide/callbacks
|
||||
guide/tensorboard
|
||||
guide/rl_zoo
|
||||
guide/migration
|
||||
guide/checking_nan
|
||||
|
|
|
|||
|
|
@ -3,11 +3,17 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.6.0a10 (WIP)
|
||||
Pre-Release 0.6.0 (2020-06-01)
|
||||
------------------------------
|
||||
|
||||
**Tensorboard support, refactored logger**
|
||||
|
||||
Breaking Changes:
|
||||
^^^^^^^^^^^^^^^^^
|
||||
- Methods were renamed in the logger:
|
||||
- ``logkv`` -> ``record``, ``writekvs`` -> ``write``, ``writeseq`` -> ``write_sequence``,
|
||||
- ``logkvs`` -> ``record_dict``, ``dumpkvs`` -> ``dump``,
|
||||
- ``getkvs`` -> ``get_log_dict``, ``logkv_mean`` -> ``record_mean``
|
||||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -17,6 +23,9 @@ New Features:
|
|||
- Added ``cmd_util`` and ``atari_wrappers``
|
||||
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
|
||||
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
|
||||
- Added support for logging to tensorboard (@rolandgvc)
|
||||
- Added ``VectorizedActionNoise`` for continuous vectorized environments (@PartiallyTyped)
|
||||
- Log evaluation in the ``EvalCallback`` using the logger
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
|
|
@ -24,6 +33,7 @@ Bug Fixes:
|
|||
- Fixed version number that had a new line included
|
||||
- Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size
|
||||
- Fixed ``sde_sample_freq`` that was not taken into account for SAC
|
||||
- Pass logger module to ``BaseCallback`` otherwise they cannot write in the one used by the algorithms
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -37,12 +47,14 @@ Others:
|
|||
- Added ``.readthedoc.yml`` file
|
||||
- Added ``flake8`` and ``make lint`` command
|
||||
- Added Github workflow
|
||||
- Added warning when passing both ``train_freq`` and ``n_episodes_rollout`` to Off-Policy Algorithms
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Added most documentation (adapted from Stable-Baselines)
|
||||
- Added link to CONTRIBUTING.md in the README (@kinalmehta)
|
||||
- Added gSDE project and update docstrings accordingly
|
||||
- Fix ``TD3`` example code block
|
||||
|
||||
|
||||
Pre-Release 0.5.0 (2020-05-05)
|
||||
|
|
@ -228,4 +240,4 @@ And all the contributors:
|
|||
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
|
||||
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
|
||||
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
|
||||
@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur
|
||||
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
|
||||
|
|
|
|||
|
|
@ -62,17 +62,20 @@ Example
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import TD3
|
||||
from stable_baselines3.td3.policies import MlpPolicy
|
||||
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
|
||||
|
||||
env = gym.make('Pendulum-v0')
|
||||
|
||||
# The noise objects for TD3
|
||||
n_actions = env.action_space.shape[-1]
|
||||
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
|
||||
|
||||
model = TD3(MlpPolicy, 'Pendulum-v0', action_noise=action_noise, verbose=1)
|
||||
model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
|
||||
model.learn(total_timesteps=10000, log_interval=10)
|
||||
model.save("td3_pendulum")
|
||||
env = model.get_env()
|
||||
|
|
|
|||
4
setup.py
4
setup.py
|
|
@ -106,7 +106,9 @@ setup(name='stable_baselines3',
|
|||
# For render
|
||||
'opencv-python',
|
||||
# For atari games,
|
||||
'atari_py~=0.2.0', 'pillow'
|
||||
'atari_py~=0.2.0', 'pillow',
|
||||
# Tensorboard support
|
||||
'tensorboard'
|
||||
]
|
||||
},
|
||||
description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
|
||||
|
|
|
|||
|
|
@ -141,13 +141,13 @@ class A2C(PPO):
|
|||
self.rollout_buffer.values.flatten())
|
||||
|
||||
self._n_updates += 1
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/explained_variance", explained_var)
|
||||
logger.record("train/entropy_loss", entropy_loss.item())
|
||||
logger.record("train/policy_loss", policy_loss.item())
|
||||
logger.record("train/value_loss", value_loss.item())
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import io
|
||||
import zipfile
|
||||
import pickle
|
||||
import warnings
|
||||
from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
|
|
@ -11,7 +12,7 @@ import gym
|
|||
import torch as th
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
from stable_baselines3.common import logger, utils
|
||||
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
|
||||
from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
|
||||
|
|
@ -35,6 +36,7 @@ class BaseRLModel(ABC):
|
|||
:param learning_rate: (float or callable) learning rate for the optimizer,
|
||||
it can be a function of the current progress (from 1 to 0)
|
||||
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
|
||||
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
|
||||
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
|
|
@ -58,6 +60,7 @@ class BaseRLModel(ABC):
|
|||
policy_base: Type[BasePolicy],
|
||||
learning_rate: Union[float, Callable],
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
|
|
@ -91,6 +94,7 @@ class BaseRLModel(ABC):
|
|||
self.start_time = None
|
||||
self.policy = None
|
||||
self.learning_rate = learning_rate
|
||||
self.tensorboard_log = tensorboard_log
|
||||
self.lr_schedule = None # type: Optional[Callable]
|
||||
self._last_obs = None # type: Optional[np.ndarray]
|
||||
# When using VecNormalize:
|
||||
|
|
@ -191,7 +195,7 @@ class BaseRLModel(ABC):
|
|||
An optimizer or a list of optimizers.
|
||||
"""
|
||||
# Log the current learning rate
|
||||
logger.logkv("learning_rate", self.lr_schedule(self._current_progress))
|
||||
logger.record("train/learning_rate", self.lr_schedule(self._current_progress))
|
||||
|
||||
if not isinstance(optimizers, list):
|
||||
optimizers = [optimizers]
|
||||
|
|
@ -289,7 +293,7 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
Return a trained model.
|
||||
|
||||
:param total_timesteps: (int) The total number of samples to train on
|
||||
:param total_timesteps: (int) The total number of samples (env steps) to train on
|
||||
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
|
||||
It takes the local and global variables. If it returns False, training is aborted.
|
||||
:param log_interval: (int) The number of timesteps before logging.
|
||||
|
|
@ -491,23 +495,27 @@ class BaseRLModel(ABC):
|
|||
return callback
|
||||
|
||||
def _setup_learn(self,
|
||||
total_timesteps: int,
|
||||
eval_env: Optional[GymEnv],
|
||||
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
|
||||
eval_freq: int = 10000,
|
||||
n_eval_episodes: int = 5,
|
||||
log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> 'BaseCallback':
|
||||
tb_log_name: str = 'run',
|
||||
) -> Tuple[int, 'BaseCallback']:
|
||||
"""
|
||||
Initialize different variables needed for training.
|
||||
|
||||
:param total_timesteps: (int) The total number of samples (env steps) to train on
|
||||
:param eval_env: (Optional[GymEnv])
|
||||
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
|
||||
:param eval_freq: (int)
|
||||
:param n_eval_episodes: (int)
|
||||
:param log_path (Optional[str]): Path to a log folder
|
||||
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
|
||||
:return: (BaseCallback)
|
||||
:param tb_log_name: (str) the name of the run for tensorboard log
|
||||
:return: (Tuple[int, BaseCallback])
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
self.ep_info_buffer = deque(maxlen=100)
|
||||
|
|
@ -519,6 +527,9 @@ class BaseRLModel(ABC):
|
|||
if reset_num_timesteps:
|
||||
self.num_timesteps = 0
|
||||
self._episode_num = 0
|
||||
else:
|
||||
# Make sure training timesteps are ahead of the internal counter
|
||||
total_timesteps += self.num_timesteps
|
||||
|
||||
# Avoid resetting the environment when calling ``.learn()`` consecutive times
|
||||
if reset_num_timesteps or self._last_obs is None:
|
||||
|
|
@ -532,10 +543,13 @@ class BaseRLModel(ABC):
|
|||
|
||||
eval_env = self._get_eval_env(eval_env)
|
||||
|
||||
# Configure logger's outputs
|
||||
utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
|
||||
|
||||
# Create eval callback if needed
|
||||
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
|
||||
|
||||
return callback
|
||||
return total_timesteps, callback
|
||||
|
||||
def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
|
||||
"""
|
||||
|
|
@ -696,6 +710,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
learning_starts: int = 100,
|
||||
batch_size: int = 256,
|
||||
policy_kwargs: Dict[str, Any] = None,
|
||||
tensorboard_log: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
support_multi_env: bool = False,
|
||||
|
|
@ -707,13 +722,13 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
use_sde_at_warmup: bool = False):
|
||||
|
||||
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate,
|
||||
policy_kwargs, verbose,
|
||||
policy_kwargs, tensorboard_log, verbose,
|
||||
device, support_multi_env, create_eval_env, monitor_wrapper,
|
||||
seed, use_sde, sde_sample_freq)
|
||||
self.buffer_size = buffer_size
|
||||
self.batch_size = batch_size
|
||||
self.learning_starts = learning_starts
|
||||
self.actor = None
|
||||
self.actor = None # type: Optional[th.nn.Module]
|
||||
self.replay_buffer = None # type: Optional[ReplayBuffer]
|
||||
# Update policy keyword arguments
|
||||
self.policy_kwargs['use_sde'] = self.use_sde
|
||||
|
|
@ -751,7 +766,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
self.replay_buffer = pickle.load(file_handler)
|
||||
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
|
||||
|
||||
def collect_rollouts(self,
|
||||
def collect_rollouts(self, # noqa: C901
|
||||
env: VecEnv,
|
||||
# Type hint as string to avoid circular import
|
||||
callback: 'BaseCallback',
|
||||
|
|
@ -882,23 +897,23 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
if action_noise is not None:
|
||||
action_noise.reset()
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and self._episode_num % log_interval == 0:
|
||||
# Log training infos
|
||||
if log_interval is not None and self._episode_num % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("episodes", self._episode_num)
|
||||
logger.record("time/episodes", self._episode_num, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
# logger.logkv("n_updates", n_updates)
|
||||
logger.logkv("fps", fps)
|
||||
logger.logkv('time_elapsed', int(time.time() - self.start_time))
|
||||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
logger.record('rollout/ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record('rollout/ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("time/fps", fps)
|
||||
logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
|
||||
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
if self.use_sde:
|
||||
logger.logkv("std", (self.actor.get_std()).mean().item())
|
||||
logger.record("train/std", (self.actor.get_std()).mean().item())
|
||||
|
||||
if len(self.ep_success_buffer) > 0:
|
||||
logger.logkv('success rate', self.safe_mean(self.ep_success_buffer))
|
||||
logger.dumpkvs()
|
||||
logger.record('rollout/success rate', self.safe_mean(self.ep_success_buffer))
|
||||
# Pass the number of timesteps for tensorboard
|
||||
logger.dump(step=self.num_timesteps)
|
||||
|
||||
mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.logger import Logger
|
||||
from stable_baselines3.common import logger
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from stable_baselines3.common.base_class import BaseRLModel # pytype: disable=pyi-error
|
||||
|
|
@ -34,7 +34,7 @@ class BaseCallback(ABC):
|
|||
self.verbose = verbose
|
||||
self.locals = None # type: Optional[Dict[str, Any]]
|
||||
self.globals = None # type: Optional[Dict[str, Any]]
|
||||
self.logger = None # type: Optional[Logger]
|
||||
self.logger = None
|
||||
# Sometimes, for event callback, it is useful
|
||||
# to have access to the parent object
|
||||
self.parent = None # type: Optional[BaseCallback]
|
||||
|
|
@ -47,7 +47,7 @@ class BaseCallback(ABC):
|
|||
"""
|
||||
self.model = model
|
||||
self.training_env = model.get_env()
|
||||
self.logger = Logger.CURRENT
|
||||
self.logger = logger
|
||||
self._init_callback()
|
||||
|
||||
def _init_callback(self) -> None:
|
||||
|
|
@ -313,6 +313,9 @@ class EvalCallback(EventCallback):
|
|||
print(f"Eval num_timesteps={self.num_timesteps}, "
|
||||
f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
|
||||
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
|
||||
# Add to current Logger
|
||||
self.logger.record('eval/mean_reward', float(mean_reward))
|
||||
self.logger.record('eval/mean_ep_length', mean_ep_length)
|
||||
|
||||
if mean_reward > self.best_mean_reward:
|
||||
if self.verbose > 0:
|
||||
|
|
|
|||
|
|
@ -5,9 +5,15 @@ import os
|
|||
import tempfile
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, TextIO, Union, Any, Optional
|
||||
from typing import Dict, List, TextIO, Union, Any, Optional, Tuple
|
||||
|
||||
import pandas
|
||||
import numpy as np
|
||||
import torch as th
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
SummaryWriter = None
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
|
|
@ -21,11 +27,14 @@ class KVWriter(object):
|
|||
Key Value writer
|
||||
"""
|
||||
|
||||
def writekvs(self, kvs: Dict) -> None:
|
||||
def write(self, key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
"""
|
||||
write a dictionary to file
|
||||
Write a dictionary to file
|
||||
|
||||
:param kvs: (dict)
|
||||
:param key_values: (dict)
|
||||
:param key_excluded: (dict)
|
||||
:param step: (int)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -41,11 +50,11 @@ class SeqWriter(object):
|
|||
sequence writer
|
||||
"""
|
||||
|
||||
def writeseq(self, seq: List):
|
||||
def write_sequence(self, sequence: List):
|
||||
"""
|
||||
write an array to file
|
||||
write an array to file
|
||||
|
||||
:param seq: (list)
|
||||
:param sequence: (list)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -65,16 +74,29 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
self.file = filename_or_file
|
||||
self.own_file = False
|
||||
|
||||
def writekvs(self, kvs: Dict) -> None:
|
||||
def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
|
||||
# Create strings for printing
|
||||
key2str = {}
|
||||
for (key, val) in sorted(kvs.items()):
|
||||
if isinstance(val, float):
|
||||
tag = None
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
|
||||
if excluded is not None and 'stdout' in excluded:
|
||||
continue
|
||||
|
||||
if isinstance(value, float):
|
||||
# Align left
|
||||
val_str = f'{val:<8.3g}'
|
||||
value_str = f'{value:<8.3g}'
|
||||
else:
|
||||
val_str = str(val)
|
||||
key2str[self._truncate(key)] = self._truncate(val_str)
|
||||
value_str = str(value)
|
||||
|
||||
if key.find('/') > 0: # Find tag and add it to the dict
|
||||
tag = key[:key.find('/') + 1]
|
||||
key2str[self._truncate(tag)] = ''
|
||||
# Remove tag from key
|
||||
if tag is not None and tag in key:
|
||||
key = str(' ' + key[len(tag):])
|
||||
|
||||
key2str[self._truncate(key)] = self._truncate(value_str)
|
||||
|
||||
# Find max widths
|
||||
if len(key2str) == 0:
|
||||
|
|
@ -87,10 +109,10 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
# Write out the data
|
||||
dashes = '-' * (key_width + val_width + 7)
|
||||
lines = [dashes]
|
||||
for (key, val) in sorted(key2str.items()):
|
||||
for key, value in key2str.items():
|
||||
key_space = ' ' * (key_width - len(key))
|
||||
val_space = ' ' * (val_width - len(val))
|
||||
lines.append(f"| {key}{key_space} | {val}{val_space} |")
|
||||
val_space = ' ' * (val_width - len(value))
|
||||
lines.append(f"| {key}{key_space} | {value}{val_space} |")
|
||||
lines.append(dashes)
|
||||
self.file.write('\n'.join(lines) + '\n')
|
||||
|
||||
|
|
@ -98,14 +120,14 @@ class HumanOutputFormat(KVWriter, SeqWriter):
|
|||
self.file.flush()
|
||||
|
||||
@classmethod
|
||||
def _truncate(cls, string: str) -> str:
|
||||
return string[:20] + '...' if len(string) > 23 else string
|
||||
def _truncate(cls, string: str, max_length: int = 23) -> str:
|
||||
return string[:max_length - 3] + '...' if len(string) > max_length else string
|
||||
|
||||
def writeseq(self, seq: List) -> None:
|
||||
seq = list(seq)
|
||||
for (i, elem) in enumerate(seq):
|
||||
def write_sequence(self, sequence: List) -> None:
|
||||
sequence = list(sequence)
|
||||
for i, elem in enumerate(sequence):
|
||||
self.file.write(elem)
|
||||
if i < len(seq) - 1: # add space unless this is the last one
|
||||
if i < len(sequence) - 1: # add space unless this is the last one
|
||||
self.file.write(' ')
|
||||
self.file.write('\n')
|
||||
self.file.flush()
|
||||
|
|
@ -127,22 +149,28 @@ class JSONOutputFormat(KVWriter):
|
|||
"""
|
||||
self.file = open(filename, 'wt')
|
||||
|
||||
def writekvs(self, kvs: Dict) -> None:
|
||||
for key, value in sorted(kvs.items()):
|
||||
def write(self, key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
|
||||
|
||||
if excluded is not None and 'json' in excluded:
|
||||
continue
|
||||
|
||||
if hasattr(value, 'dtype'):
|
||||
if value.shape == () or len(value) == 1:
|
||||
# if value is a dimensionless numpy array or of length 1, serialize as a float
|
||||
kvs[key] = float(value)
|
||||
key_values[key] = float(value)
|
||||
else:
|
||||
# otherwise, a value is a numpy array, serialize as a list or nested lists
|
||||
kvs[key] = value.tolist()
|
||||
self.file.write(json.dumps(kvs) + '\n')
|
||||
key_values[key] = value.tolist()
|
||||
self.file.write(json.dumps(key_values) + '\n')
|
||||
self.file.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
closes the file
|
||||
"""
|
||||
|
||||
self.file.close()
|
||||
|
||||
|
||||
|
|
@ -153,13 +181,15 @@ class CSVOutputFormat(KVWriter):
|
|||
|
||||
:param filename: (str) the file to write the log to
|
||||
"""
|
||||
|
||||
self.file = open(filename, 'w+t')
|
||||
self.keys = []
|
||||
self.sep = ','
|
||||
self.separator = ','
|
||||
|
||||
def writekvs(self, kvs: Dict) -> None:
|
||||
def write(self, key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
# Add our current row to the history
|
||||
extra_keys = kvs.keys() - self.keys
|
||||
extra_keys = key_values.keys() - self.keys
|
||||
if extra_keys:
|
||||
self.keys.extend(extra_keys)
|
||||
self.file.seek(0)
|
||||
|
|
@ -172,12 +202,12 @@ class CSVOutputFormat(KVWriter):
|
|||
self.file.write('\n')
|
||||
for line in lines[1:]:
|
||||
self.file.write(line[:-1])
|
||||
self.file.write(self.sep * len(extra_keys))
|
||||
self.file.write(self.separator * len(extra_keys))
|
||||
self.file.write('\n')
|
||||
for i, key in enumerate(self.keys):
|
||||
if i > 0:
|
||||
self.file.write(',')
|
||||
value = kvs.get(key)
|
||||
value = key_values.get(key)
|
||||
if value is not None:
|
||||
self.file.write(str(value))
|
||||
self.file.write('\n')
|
||||
|
|
@ -190,25 +220,49 @@ class CSVOutputFormat(KVWriter):
|
|||
self.file.close()
|
||||
|
||||
|
||||
def valid_float_value(value: Any) -> bool:
|
||||
"""
|
||||
Returns True if the value can be successfully cast into a float
|
||||
class TensorBoardOutputFormat(KVWriter):
|
||||
def __init__(self, folder: str):
|
||||
"""
|
||||
Dumps key/value pairs into TensorBoard's numeric format.
|
||||
|
||||
:param value: (Any) the value to check
|
||||
:return: (bool)
|
||||
"""
|
||||
try:
|
||||
float(value)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
:param folder: (str) the folder to write the log to
|
||||
"""
|
||||
assert SummaryWriter is not None, ("tensorboard is not installed, you can use "
|
||||
"pip install tensorboard to do so")
|
||||
self.writer = SummaryWriter(log_dir=folder)
|
||||
|
||||
def write(self, key_values: Dict[str, Any],
|
||||
key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
|
||||
|
||||
for (key, value), (_, excluded) in zip(sorted(key_values.items()),
|
||||
sorted(key_excluded.items())):
|
||||
|
||||
if excluded is not None and 'tensorboard' in excluded:
|
||||
continue
|
||||
|
||||
if isinstance(value, np.ScalarType):
|
||||
self.writer.add_scalar(key, value, step)
|
||||
|
||||
if isinstance(value, th.Tensor):
|
||||
self.writer.add_histogram(key, value, step)
|
||||
|
||||
# Flush the output to the file
|
||||
self.writer.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
closes the file
|
||||
"""
|
||||
if self.writer:
|
||||
self.writer.close()
|
||||
self.writer = None
|
||||
|
||||
|
||||
def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWriter:
|
||||
"""
|
||||
return a logger for the requested format
|
||||
|
||||
:param _format: (str) the requested format to log to ('stdout', 'log', 'json' or 'csv')
|
||||
:param _format: (str) the requested format to log to ('stdout', 'log', 'json', 'csv' or 'tensorboard')
|
||||
:param log_dir: (str) the logging directory
|
||||
:param log_suffix: (str) the suffix for the log file
|
||||
:return: (KVWriter) the logger
|
||||
|
|
@ -222,6 +276,8 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWr
|
|||
return JSONOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.json'))
|
||||
elif _format == 'csv':
|
||||
return CSVOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.csv'))
|
||||
elif _format == 'tensorboard':
|
||||
return TensorBoardOutputFormat(log_dir)
|
||||
else:
|
||||
raise ValueError(f'Unknown format specified: {_format}')
|
||||
|
||||
|
|
@ -230,52 +286,56 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWr
|
|||
# API
|
||||
# ================================================================
|
||||
|
||||
def logkv(key: Any, val: Any) -> None:
|
||||
def record(key: str, value: Any,
|
||||
exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
If called many times, last value will be used.
|
||||
|
||||
:param key: (Any) save to log this key
|
||||
:param val: (Any) save to log this value
|
||||
:param value: (Any) save to log this value
|
||||
:param exclude: (str or tuple) outputs to be excluded
|
||||
"""
|
||||
Logger.CURRENT.logkv(key, val)
|
||||
Logger.CURRENT.record(key, value, exclude)
|
||||
|
||||
|
||||
def logkv_mean(key: Any, val: Union[int, float]) -> None:
|
||||
def record_mean(key: str, value: Union[int, float],
|
||||
exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
The same as logkv(), but if called many times, values averaged.
|
||||
The same as record(), but if called many times, values averaged.
|
||||
|
||||
:param key: (Any) save to log this key
|
||||
:param val: (Number) save to log this value
|
||||
:param value: (Number) save to log this value
|
||||
:param exclude: (str or tuple) outputs to be excluded
|
||||
"""
|
||||
Logger.CURRENT.logkv_mean(key, val)
|
||||
Logger.CURRENT.record_mean(key, value, exclude)
|
||||
|
||||
|
||||
def logkvs(key_values: Dict) -> None:
|
||||
def record_dict(key_values: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log a dictionary of key-value pairs
|
||||
Log a dictionary of key-value pairs.
|
||||
|
||||
:param key_values: (dict) the list of keys and values to save to log
|
||||
"""
|
||||
for key, value in key_values.items():
|
||||
logkv(key, value)
|
||||
record(key, value)
|
||||
|
||||
|
||||
def dumpkvs() -> None:
|
||||
def dump(step: int = 0) -> None:
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
"""
|
||||
Logger.CURRENT.dumpkvs()
|
||||
Logger.CURRENT.dump(step)
|
||||
|
||||
|
||||
def getkvs() -> Dict:
|
||||
def get_log_dict() -> Dict:
|
||||
"""
|
||||
get the key values logs
|
||||
|
||||
:return: (dict) the logged values
|
||||
"""
|
||||
return Logger.CURRENT.name2val
|
||||
return Logger.CURRENT.name_to_value
|
||||
|
||||
|
||||
def log(*args, level: int = INFO) -> None:
|
||||
|
|
@ -363,8 +423,8 @@ def get_dir() -> str:
|
|||
return Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
record_tabular = logkv
|
||||
dump_tabular = dumpkvs
|
||||
record_tabular = record
|
||||
dump_tabular = dump
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
|
@ -384,50 +444,59 @@ class Logger(object):
|
|||
:param folder: (str) the logging location
|
||||
:param output_formats: ([str]) the list of output format
|
||||
"""
|
||||
self.name2val = defaultdict(float) # values this iteration
|
||||
self.name2cnt = defaultdict(int)
|
||||
self.name_to_value = defaultdict(float) # values this iteration
|
||||
self.name_to_count = defaultdict(int)
|
||||
self.name_to_excluded = defaultdict(str)
|
||||
self.level = INFO
|
||||
self.dir = folder
|
||||
self.output_formats = output_formats
|
||||
|
||||
# Logging API, forwarded
|
||||
# ----------------------------------------
|
||||
def logkv(self, key: Any, val: Any) -> None:
|
||||
def record(self, key: str, value: Any,
|
||||
exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
Log a value of some diagnostic
|
||||
Call this once for each diagnostic quantity, each iteration
|
||||
If called many times, last value will be used.
|
||||
|
||||
:param key: (Any) save to log this key
|
||||
:param val: (Any) save to log this value
|
||||
:param value: (Any) save to log this value
|
||||
:param exclude: (str or tuple) outputs to be excluded
|
||||
"""
|
||||
self.name2val[key] = val
|
||||
self.name_to_value[key] = value
|
||||
self.name_to_excluded[key] = exclude
|
||||
|
||||
def logkv_mean(self, key: Any, val: Any) -> None:
|
||||
def record_mean(self, key: str, value: Any,
|
||||
exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
|
||||
"""
|
||||
The same as logkv(), but if called many times, values averaged.
|
||||
The same as record(), but if called many times, values averaged.
|
||||
|
||||
:param key: (Any) save to log this key
|
||||
:param val: (Number) save to log this value
|
||||
:param value: (Number) save to log this value
|
||||
:param exclude: (str or tuple) outputs to be excluded
|
||||
"""
|
||||
if val is None:
|
||||
self.name2val[key] = None
|
||||
if value is None:
|
||||
self.name_to_value[key] = None
|
||||
return
|
||||
oldval, cnt = self.name2val[key], self.name2cnt[key]
|
||||
self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
|
||||
self.name2cnt[key] = cnt + 1
|
||||
old_val, count = self.name_to_value[key], self.name_to_count[key]
|
||||
self.name_to_value[key] = old_val * count / (count + 1) + value / (count + 1)
|
||||
self.name_to_count[key] = count + 1
|
||||
self.name_to_excluded[key] = exclude
|
||||
|
||||
def dumpkvs(self) -> None:
|
||||
def dump(self, step: int = 0) -> None:
|
||||
"""
|
||||
Write all of the diagnostics from the current iteration
|
||||
"""
|
||||
if self.level == DISABLED:
|
||||
return
|
||||
for fmt in self.output_formats:
|
||||
if isinstance(fmt, KVWriter):
|
||||
fmt.writekvs(self.name2val)
|
||||
self.name2val.clear()
|
||||
self.name2cnt.clear()
|
||||
for _format in self.output_formats:
|
||||
if isinstance(_format, KVWriter):
|
||||
_format.write(self.name_to_value, self.name_to_excluded, step)
|
||||
|
||||
self.name_to_value.clear()
|
||||
self.name_to_count.clear()
|
||||
self.name_to_excluded.clear()
|
||||
|
||||
def log(self, *args, level: int = INFO) -> None:
|
||||
"""
|
||||
|
|
@ -466,8 +535,8 @@ class Logger(object):
|
|||
"""
|
||||
closes the file
|
||||
"""
|
||||
for fmt in self.output_formats:
|
||||
fmt.close()
|
||||
for _format in self.output_formats:
|
||||
_format.close()
|
||||
|
||||
# Misc
|
||||
# ----------------------------------------
|
||||
|
|
@ -477,37 +546,37 @@ class Logger(object):
|
|||
|
||||
:param args: (list) the arguments to log
|
||||
"""
|
||||
for fmt in self.output_formats:
|
||||
if isinstance(fmt, SeqWriter):
|
||||
fmt.writeseq(map(str, args))
|
||||
for _format in self.output_formats:
|
||||
if isinstance(_format, SeqWriter):
|
||||
_format.write_sequence(map(str, args))
|
||||
|
||||
|
||||
# Initialize logger
|
||||
Logger.DEFAULT = Logger.CURRENT = Logger(folder=None, output_formats=[HumanOutputFormat(sys.stdout)])
|
||||
|
||||
|
||||
def configure(folder: Optional[str] = None, format_strs: Optional[List[str]] = None) -> None:
|
||||
def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
configure the current logger
|
||||
|
||||
:param folder: (Optional[str]) the save location
|
||||
(if None, $BASELINES_LOGDIR, if still None, tempdir/baselines-[date & time])
|
||||
:param format_strs: (Optional[List[str]]) the output logging format
|
||||
(if None, $BASELINES_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
|
||||
(if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time])
|
||||
:param format_strings: (Optional[List[str]]) the output logging format
|
||||
(if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
|
||||
"""
|
||||
if folder is None:
|
||||
folder = os.getenv('BASELINES_LOGDIR')
|
||||
folder = os.getenv('SB3_LOGDIR')
|
||||
if folder is None:
|
||||
folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("baselines-%Y-%m-%d-%H-%M-%S-%f"))
|
||||
folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
|
||||
assert isinstance(folder, str)
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
|
||||
log_suffix = ''
|
||||
if format_strs is None:
|
||||
format_strs = os.getenv('BASELINES_LOG_FORMAT', 'stdout,log,csv').split(',')
|
||||
if format_strings is None:
|
||||
format_strings = os.getenv('SB3_LOG_FORMAT', 'stdout,log,csv').split(',')
|
||||
|
||||
format_strs = filter(None, format_strs)
|
||||
output_formats = [make_output_format(f, folder, log_suffix) for f in format_strs]
|
||||
format_strings = filter(None, format_strings)
|
||||
output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
|
||||
|
||||
Logger.CURRENT = Logger(folder=folder, output_formats=output_formats)
|
||||
log(f'Logging to {folder}')
|
||||
|
|
@ -524,28 +593,28 @@ def reset() -> None:
|
|||
|
||||
|
||||
class ScopedConfigure(object):
|
||||
def __init__(self, folder: Optional[str] = None, format_strs: Optional[List[str]] = None):
|
||||
def __init__(self, folder: Optional[str] = None, format_strings: Optional[List[str]] = None):
|
||||
"""
|
||||
Class for using context manager while logging
|
||||
|
||||
usage:
|
||||
with ScopedConfigure(folder=None, format_strs=None):
|
||||
with ScopedConfigure(folder=None, format_strings=None):
|
||||
{code}
|
||||
|
||||
:param folder: (str) the logging folder
|
||||
:param format_strs: ([str]) the list of output logging format
|
||||
:param format_strings: ([str]) the list of output logging format
|
||||
"""
|
||||
self.dir = folder
|
||||
self.format_strs = format_strs
|
||||
self.prevlogger = None
|
||||
self.format_strings = format_strings
|
||||
self.prev_logger = None
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self.prevlogger = Logger.CURRENT
|
||||
configure(folder=self.dir, format_strs=self.format_strs)
|
||||
self.prev_logger = Logger.CURRENT
|
||||
configure(folder=self.dir, format_strings=self.format_strings)
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
Logger.CURRENT.close()
|
||||
Logger.CURRENT = self.prevlogger
|
||||
Logger.CURRENT = self.prev_logger
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, List, Iterable
|
||||
from abc import ABC, abstractmethod
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
|
@ -45,7 +46,7 @@ class NormalActionNoise(ActionNoise):
|
|||
|
||||
class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
||||
"""
|
||||
A Ornstein Uhlenbeck action noise, this is designed to aproximate brownian motion with friction.
|
||||
An Ornstein Uhlenbeck action noise, this is designed to approximate Brownian motion with friction.
|
||||
|
||||
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
|
||||
|
||||
|
|
@ -84,3 +85,81 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return f'OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})'
|
||||
|
||||
|
||||
class VectorizedActionNoise(ActionNoise):
    """
    A Vectorized action noise for parallel environments.

    One independent deep copy of ``base_noise`` is kept per environment;
    calling the object samples every copy and stacks the results along a
    new first axis.

    :param base_noise: (ActionNoise) The noise generator to use
    :param n_envs: (int) The number of parallel environments
    :raises ValueError: if ``n_envs`` cannot be converted to a positive integer,
        or if ``base_noise`` is None
    :raises TypeError: if ``base_noise`` is not an instance of ``ActionNoise``
    """

    def __init__(self, base_noise: ActionNoise, n_envs: int):
        try:
            self.n_envs = int(n_envs)
        except (TypeError, ValueError) as err:
            # int() raises TypeError for None and ValueError for non-numeric strings;
            # report both with the same message.
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0") from err
        # Explicit check instead of `assert`, so the validation survives `python -O`.
        if self.n_envs <= 0:
            raise ValueError(f"Expected n_envs={n_envs} to be positive integer greater than 0")

        self.base_noise = base_noise
        # Use the validated integer rather than the raw argument.
        self.noises = [copy.deepcopy(self.base_noise) for _ in range(self.n_envs)]

    def reset(self, indices: Optional[Iterable[int]] = None) -> None:
        """
        Reset all the noise processes, or those listed in indices

        :param indices: Optional[Iterable[int]] The indices to reset. Default: None.
            If the parameter is None, then all processes are reset to their initial position.
        """
        if indices is None:
            indices = range(len(self.noises))

        for index in indices:
            self.noises[index].reset()

    def __repr__(self) -> str:
        return f"VecNoise(BaseNoise={repr(self.base_noise)}), n_envs={len(self.noises)})"

    def __call__(self) -> np.ndarray:
        """
        Generate and stack the action noise from each noise object

        :return: (np.ndarray) the samples of all noise objects, stacked along the first axis
        """
        return np.stack([noise() for noise in self.noises])

    @property
    def base_noise(self) -> ActionNoise:
        """The noise object that every per-environment noise was copied from."""
        return self._base_noise

    @base_noise.setter
    def base_noise(self, base_noise: ActionNoise) -> None:
        if base_noise is None:
            raise ValueError("Expected base_noise to be an instance of ActionNoise, not None", ActionNoise)
        if not isinstance(base_noise, ActionNoise):
            raise TypeError("Expected base_noise to be an instance of type ActionNoise", ActionNoise)
        self._base_noise = base_noise

    @property
    def noises(self) -> List[ActionNoise]:
        """The list of per-environment noise objects."""
        return self._noises

    @noises.setter
    def noises(self, noises: List[ActionNoise]) -> None:
        """
        Replace the per-environment noise objects.

        Every new noise object is reset after assignment.

        :param noises: (List[ActionNoise]) exactly ``n_envs`` noise objects of the
            same type as ``base_noise``
        :raises TypeError: if ``noises`` is not iterable
        :raises AssertionError: if the number of noise objects differs from ``n_envs``
        :raises ValueError: if some noise object is not of the type of ``base_noise``
        """
        noises = list(noises)  # raises TypeError if not iterable
        # Raised explicitly (not via `assert`) so the check is not stripped by `python -O`;
        # the exception type is kept for backward compatibility.
        if len(noises) != self.n_envs:
            raise AssertionError(f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}.")

        different_types = [
            i for i, noise in enumerate(noises)
            if not isinstance(noise, type(self.base_noise))
        ]

        if different_types:
            raise ValueError(
                f"Noise instances at indices {different_types} don't match the type of base_noise",
                type(self.base_noise)
            )

        self._noises = noises
        for noise in noises:
            noise.reset()
|
||||
|
|
|
|||
|
|
@ -1,8 +1,17 @@
|
|||
from typing import Callable, Union
|
||||
from typing import Callable, Union, Optional
|
||||
import random
|
||||
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch as th
|
||||
# Check if tensorboard is available for pytorch
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
SummaryWriter = None
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
|
||||
|
||||
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
|
||||
|
|
@ -110,3 +119,42 @@ def get_device(device: Union[th.device, str] = 'auto') -> th.device:
|
|||
return th.device('cpu')
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_latest_run_id(log_path: Optional[str] = None, log_name: str = '') -> int:
    """
    Returns the latest run number for the given log name and log path,
    by finding the greatest number in the directories.

    Only entries of ``log_path`` named exactly ``{log_name}_{number}`` are considered.

    :param log_path: (Optional[str]) the directory to search in
        NOTE(review): None is formatted as the literal directory name "None",
        as in the previous implementation -- confirm whether callers rely on the default.
    :param log_name: (str) the name of the run (the part before the ``_{number}`` suffix)
    :return: (int) latest run number (0 if there is no matching entry)
    """
    max_run_id = 0
    # Escape glob metacharacters so that names such as "run[a]" are matched literally.
    pattern = os.path.join(glob.escape(str(log_path)), f"{glob.escape(log_name)}_[0-9]*")
    for path in glob.glob(pattern):
        prefix, _, suffix = os.path.basename(path).rpartition("_")
        # The glob only checks the first character of the suffix; require all digits here.
        if prefix == log_name and suffix.isdigit():
            max_run_id = max(max_run_id, int(suffix))
    return max_run_id
|
||||
|
||||
|
||||
def configure_logger(verbose: int = 0, tensorboard_log: Optional[str] = None,
                     tb_log_name: str = '', reset_num_timesteps: bool = True) -> None:
    """
    Configure the logger's outputs.

    :param verbose: (int) the verbosity level: 0 no output, 1 info, 2 debug
    :param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
    :param tb_log_name: (str) prefix of the tensorboard run directory,
        which is named ``{tb_log_name}_{run_id}``
    :param reset_num_timesteps: (bool) if True, a new run directory is used
        (latest run id + 1); if False, the latest run id is reused
    """
    tensorboard_enabled = tensorboard_log is not None and SummaryWriter is not None

    if not tensorboard_enabled:
        # Without tensorboard, only the silent case reconfigures the logger
        if verbose == 0:
            logger.configure(format_strings=[""])
        return

    run_id = get_latest_run_id(tensorboard_log, tb_log_name)
    if reset_num_timesteps:
        # Start a new run directory; otherwise continue training in the same one
        run_id += 1
    save_path = os.path.join(tensorboard_log, f"{tb_log_name}_{run_id}")

    output_formats = ["stdout", "tensorboard"] if verbose >= 1 else ["tensorboard"]
    logger.configure(save_path, output_formats)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,6 @@ import gym
|
|||
from gym import spaces
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Check if tensorboard is available for pytorch
|
||||
# TODO: finish tensorboard integration
|
||||
# try:
|
||||
# from torch.utils.tensorboard import SummaryWriter
|
||||
# except ImportError:
|
||||
# SummaryWriter = None
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common import logger
|
||||
|
|
@ -95,10 +88,11 @@ class PPO(BaseRLModel):
|
|||
policy_kwargs: Optional[Dict[str, Any]] = None,
|
||||
verbose: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
device: Union[th.device, str] = "auto",
|
||||
_init_setup_model: bool = True):
|
||||
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate, policy_kwargs=policy_kwargs,
|
||||
super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate,
|
||||
policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log,
|
||||
verbose=verbose, device=device, use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
create_eval_env=create_eval_env, support_multi_env=True, seed=seed)
|
||||
|
||||
|
|
@ -114,7 +108,6 @@ class PPO(BaseRLModel):
|
|||
self.max_grad_norm = max_grad_norm
|
||||
self.rollout_buffer = None
|
||||
self.target_kl = target_kl
|
||||
self.tensorboard_log = tensorboard_log
|
||||
self.tb_writer = None
|
||||
|
||||
if _init_setup_model:
|
||||
|
|
@ -136,8 +129,8 @@ class PPO(BaseRLModel):
|
|||
self.clip_range = get_schedule_fn(self.clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
if isinstance(self.clip_range_vf, (float, int)):
|
||||
assert self.clip_range_vf > 0, ('`clip_range_vf` must be positive, '
|
||||
'pass `None` to deactivate vf clipping')
|
||||
assert self.clip_range_vf > 0, ("`clip_range_vf` must be positive, "
|
||||
"pass `None` to deactivate vf clipping")
|
||||
|
||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||
|
||||
|
|
@ -231,6 +224,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
# ratio between old and new policy, should be one at the first iteration
|
||||
ratio = th.exp(log_prob - rollout_data.old_log_prob)
|
||||
|
||||
# clipped surrogate loss
|
||||
policy_loss_1 = advantages * ratio
|
||||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
|
|
@ -282,19 +276,21 @@ class PPO(BaseRLModel):
|
|||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("clip_fraction", np.mean(clip_fraction))
|
||||
logger.logkv("clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
# Logs
|
||||
logger.record("train/entropy_loss", np.mean(entropy_losses))
|
||||
logger.record("train/policy_gradient_loss", np.mean(pg_losses))
|
||||
logger.record("train/value_loss", np.mean(value_losses))
|
||||
logger.record("train/approx_kl", np.mean(approx_kl_divs))
|
||||
logger.record("train/clip_fraction", np.mean(clip_fraction))
|
||||
logger.record("train/loss", loss.item())
|
||||
logger.record("train/explained_variance", explained_var)
|
||||
if hasattr(self.policy, "log_std"):
|
||||
logger.record("train/std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
logger.logkv("approx_kl", np.mean(approx_kl_divs))
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
logger.logkv("entropy_loss", np.mean(entropy_losses))
|
||||
logger.logkv("policy_gradient_loss", np.mean(pg_losses))
|
||||
logger.logkv("value_loss", np.mean(value_losses))
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
|
||||
logger.record("train/clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
logger.record("train/clip_range_vf", clip_range_vf)
|
||||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
|
|
@ -305,15 +301,12 @@ class PPO(BaseRLModel):
|
|||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "PPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> 'PPO':
|
||||
reset_num_timesteps: bool = True) -> "PPO":
|
||||
|
||||
iteration = 0
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
|
||||
# if self.tensorboard_log is not None and SummaryWriter is not None:
|
||||
# self.tb_writer = SummaryWriter(log_dir=os.path.join(self.tensorboard_log, tb_log_name))
|
||||
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
tb_log_name)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
|
@ -328,24 +321,22 @@ class PPO(BaseRLModel):
|
|||
iteration += 1
|
||||
self._update_current_progress(self.num_timesteps, total_timesteps)
|
||||
|
||||
# Display training infos
|
||||
if self.verbose >= 1 and log_interval is not None and iteration % log_interval == 0:
|
||||
# Log training infos
|
||||
if log_interval is not None and iteration % log_interval == 0:
|
||||
fps = int(self.num_timesteps / (time.time() - self.start_time))
|
||||
logger.logkv("iterations", iteration)
|
||||
logger.record("time/iterations", iteration, exclude="tensorboard")
|
||||
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
|
||||
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
|
||||
logger.logkv("fps", fps)
|
||||
logger.logkv('time_elapsed', int(time.time() - self.start_time))
|
||||
logger.logkv("total timesteps", self.num_timesteps)
|
||||
logger.dumpkvs()
|
||||
logger.record("rollout/ep_rew_mean",
|
||||
self.safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("rollout/ep_len_mean",
|
||||
self.safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
|
||||
logger.record("time/fps", fps)
|
||||
logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
|
||||
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
|
||||
logger.dump(step=self.num_timesteps)
|
||||
|
||||
self.train(self.n_epochs, batch_size=self.batch_size)
|
||||
|
||||
# For tensorboard integration
|
||||
# if self.tb_writer is not None:
|
||||
# self.tb_writer.add_scalar('Eval/reward', mean_reward, self.num_timesteps)
|
||||
|
||||
callback.on_training_end()
|
||||
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@ -90,7 +89,7 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, learning_rate,
|
||||
buffer_size, learning_starts, batch_size,
|
||||
policy_kwargs, verbose, device,
|
||||
policy_kwargs, tensorboard_log, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
|
@ -237,12 +236,12 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("ent_coef", np.mean(ent_coefs))
|
||||
logger.logkv("actor_loss", np.mean(actor_losses))
|
||||
logger.logkv("critic_loss", np.mean(critic_losses))
|
||||
logger.record("train/n_updates", self._n_updates, exclude='tensorboard')
|
||||
logger.record("train/ent_coef", np.mean(ent_coefs))
|
||||
logger.record("train/actor_loss", np.mean(actor_losses))
|
||||
logger.record("train/critic_loss", np.mean(critic_losses))
|
||||
if len(ent_coef_losses) > 0:
|
||||
logger.logkv("ent_coef_loss", np.mean(ent_coef_losses))
|
||||
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
|
|
@ -255,8 +254,9 @@ class SAC(OffPolicyRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
tb_log_name)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class TD3(OffPolicyRLModel):
|
|||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, learning_rate,
|
||||
buffer_size, learning_starts, batch_size,
|
||||
policy_kwargs, verbose, device,
|
||||
policy_kwargs, tensorboard_log, verbose, device,
|
||||
create_eval_env=create_eval_env, seed=seed,
|
||||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
|
@ -176,7 +176,7 @@ class TD3(OffPolicyRLModel):
|
|||
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
||||
|
||||
self._n_updates += gradient_steps
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.record("train/n_updates", self._n_updates, exclude='tensorboard')
|
||||
|
||||
def train_sde(self) -> None:
|
||||
# Update optimizer learning rate
|
||||
|
|
@ -235,9 +235,9 @@ class TD3(OffPolicyRLModel):
|
|||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True) -> OffPolicyRLModel:
|
||||
|
||||
callback = self._setup_learn(eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps)
|
||||
|
||||
total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
|
||||
n_eval_episodes, eval_log_path, reset_num_timesteps,
|
||||
tb_log_name)
|
||||
callback.on_training_start(locals(), globals())
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
0.6.0a10
|
||||
0.6.0
|
||||
|
|
|
|||
|
|
@ -1,12 +1,9 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
|
||||
info, debug, set_level, configure, logkv, logkvs,
|
||||
dumpkvs, logkv_mean, warn, error, reset)
|
||||
info, debug, set_level, configure, record, record_dict,
|
||||
dump, record_mean, warn, error, reset)
|
||||
|
||||
KEY_VALUES = {
|
||||
"test": 1,
|
||||
|
|
@ -18,10 +15,12 @@ KEY_VALUES = {
|
|||
"g": np.array([[[1]]]),
|
||||
}
|
||||
|
||||
LOG_DIR = '/tmp/stable_baselines3/'
|
||||
KEY_EXCLUDED = {}
|
||||
for key in KEY_VALUES.keys():
|
||||
KEY_EXCLUDED[key] = None
|
||||
|
||||
|
||||
def test_main():
|
||||
def test_main(tmp_path):
|
||||
"""
|
||||
tests for the logger module
|
||||
"""
|
||||
|
|
@ -29,55 +28,56 @@ def test_main():
|
|||
debug("shouldn't appear")
|
||||
set_level(DEBUG)
|
||||
debug("should appear")
|
||||
folder = "/tmp/testlogging"
|
||||
if os.path.exists(folder):
|
||||
shutil.rmtree(folder)
|
||||
configure(folder=folder)
|
||||
logkv("a", 3)
|
||||
logkv("b", 2.5)
|
||||
dumpkvs()
|
||||
logkv("b", -2.5)
|
||||
logkv("a", 5.5)
|
||||
dumpkvs()
|
||||
configure(folder=str(tmp_path))
|
||||
record("a", 3)
|
||||
record("b", 2.5)
|
||||
dump()
|
||||
record("b", -2.5)
|
||||
record("a", 5.5)
|
||||
dump()
|
||||
info("^^^ should see a = 5.5")
|
||||
logkv_mean("b", -22.5)
|
||||
logkv_mean("b", -44.4)
|
||||
logkv("a", 5.5)
|
||||
dumpkvs()
|
||||
record_mean("b", -22.5)
|
||||
record_mean("b", -44.4)
|
||||
record("a", 5.5)
|
||||
dump()
|
||||
with ScopedConfigure(None, None):
|
||||
info("^^^ should see b = 33.3")
|
||||
|
||||
with ScopedConfigure("/tmp/test-logger/", ["json"]):
|
||||
logkv("b", -2.5)
|
||||
dumpkvs()
|
||||
with ScopedConfigure(str(tmp_path / "test-logger"), ["json"]):
|
||||
record("b", -2.5)
|
||||
dump()
|
||||
|
||||
reset()
|
||||
logkv("a", "longasslongasslongasslongasslongasslongassvalue")
|
||||
dumpkvs()
|
||||
record("a", "longasslongasslongasslongasslongasslongassvalue")
|
||||
dump()
|
||||
warn("hey")
|
||||
error("oh")
|
||||
logkvs({"test": 1})
|
||||
record_dict({"test": 1})
|
||||
|
||||
|
||||
@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv'])
|
||||
def test_make_output(_format):
|
||||
@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv', 'tensorboard'])
|
||||
def test_make_output(tmp_path, _format):
|
||||
"""
|
||||
test make output
|
||||
|
||||
:param _format: (str) output format
|
||||
"""
|
||||
writer = make_output_format(_format, LOG_DIR)
|
||||
writer.writekvs(KEY_VALUES)
|
||||
if _format == 'tensorboard':
|
||||
# Skip if no tensorboard installed
|
||||
pytest.importorskip("tensorboard")
|
||||
|
||||
writer = make_output_format(_format, tmp_path)
|
||||
writer.write(KEY_VALUES, KEY_EXCLUDED)
|
||||
if _format == "csv":
|
||||
read_csv(LOG_DIR + 'progress.csv')
|
||||
read_csv(tmp_path / 'progress.csv')
|
||||
elif _format == 'json':
|
||||
read_json(LOG_DIR + 'progress.json')
|
||||
read_json(tmp_path / 'progress.json')
|
||||
writer.close()
|
||||
|
||||
|
||||
def test_make_output_fail():
|
||||
def test_make_output_fail(tmp_path):
|
||||
"""
|
||||
test value error on logger
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
make_output_format('dummy_format', LOG_DIR)
|
||||
make_output_format('dummy_format', tmp_path)
|
||||
|
|
|
|||
36
tests/test_tensorboard.py
Normal file
36
tests/test_tensorboard.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import os
|
||||
import pytest
|
||||
|
||||
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||||
|
||||
# Algorithm class and environment id for each tested model
MODEL_DICT = {
    'a2c': (A2C, 'CartPole-v1'),
    'ppo': (PPO, 'CartPole-v1'),
    'sac': (SAC, 'Pendulum-v0'),
    'td3': (TD3, 'Pendulum-v0'),
}

# Number of timesteps for each call to learn()
N_STEPS = 100


@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
    """
    Check the naming of the tensorboard run directories.

    :param tmp_path: (Path) temporary directory used as tensorboard log location
    :param model_name: (str) key of ``MODEL_DICT``
    """
    # Skip if no tensorboard installed
    pytest.importorskip("tensorboard")

    algo, env_id = MODEL_DICT[model_name]
    default_name = model_name.upper()
    model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=tmp_path)

    # The second call continues the first run: no new directory is expected
    model.learn(N_STEPS)
    model.learn(N_STEPS, reset_num_timesteps=False)

    assert os.path.isdir(tmp_path / f"{default_name}_1")
    assert not os.path.isdir(tmp_path / f"{default_name}_2")

    # Two independent runs with an explicit log name
    custom_name = "tb_multiple_runs_" + model_name
    for _ in range(2):
        model.learn(N_STEPS, tb_log_name=custom_name)

    assert os.path.isdir(tmp_path / f"{custom_name}_1")
    # Check that the log dir name increments correctly
    assert os.path.isdir(tmp_path / f"{custom_name}_2")
|
||||
|
|
@ -10,6 +10,8 @@ from stable_baselines3.common.monitor import Monitor
|
|||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
from stable_baselines3.common.noise import (
|
||||
VectorizedActionNoise, OrnsteinUhlenbeckActionNoise, ActionNoise)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
|
||||
|
|
@ -107,3 +109,36 @@ def test_evaluate_policy():
|
|||
|
||||
episode_rewards, _ = evaluate_policy(model, model.get_env(), n_eval_episodes, return_episode_rewards=True)
|
||||
assert len(episode_rewards) == n_eval_episodes
|
||||
|
||||
|
||||
def test_vec_noise():
    """
    Exercise ``VectorizedActionNoise`` built on top of an Ornstein-Uhlenbeck noise.

    The test checks, in order:

    - invalid ``n_envs`` values (``-1``, ``None``, a string) raise ``ValueError``;
    - a valid instance reports ``n_envs`` and samples with shape ``(num_envs, num_actions)``;
    - invalid base noises are rejected at construction (``None`` -> ``ValueError``,
      an int -> ``TypeError``);
    - invalid assignments to ``noises`` are rejected and leave the stored noises
      with the expected type and length.
    """
    num_envs = 4
    num_actions = 10
    mu = np.zeros(num_actions)
    sigma = 0.4 * np.ones(num_actions)
    base: ActionNoise = OrnsteinUhlenbeckActionNoise(mu, sigma)

    # Invalid number of environments
    for bad_n_envs in (-1, None, "whatever"):
        with pytest.raises(ValueError):
            VectorizedActionNoise(base, bad_n_envs)

    vec = VectorizedActionNoise(base, num_envs)
    assert vec.n_envs == num_envs
    assert vec().shape == (num_envs, num_actions)
    assert not (vec() == base()).all()

    # Invalid base noise objects
    for expected_error, bad_base in ((ValueError, None), (TypeError, 12)):
        with pytest.raises(expected_error):
            VectorizedActionNoise(bad_base, num_envs)

    # Invalid values for the ``noises`` attribute
    rejected_noises = (
        (AssertionError, []),
        (TypeError, None),
        (ValueError, [None] * vec.n_envs),
        (AssertionError, [base] * (num_envs - 1)),
    )
    for expected_error, bad_noises in rejected_noises:
        with pytest.raises(expected_error):
            vec.noises = bad_noises

    # The stored noises are expected to be unchanged by the rejected assignments
    assert all(isinstance(noise, type(base)) for noise in vec.noises)
    assert len(vec.noises) == num_envs
|
||||
|
|
|
|||
Loading…
Reference in a new issue