Mirror of https://github.com/saymrwulf/stable-baselines3.git (synced 2026-06-01 23:30:53 +00:00)
Documentation update and style fixes (#21)
* Update doc: add gSDE
* Fix codestyle
* Remove travis script
* Add lint check to gitlab
parent 54f6f5b6fb
commit 15ff6d47ee
25 changed files with 88 additions and 166 deletions
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
|
|
@@ -7,7 +7,7 @@
|
|||
<!--- Why is this change required? What problem does it solve? -->
|
||||
<!--- If it fixes an open issue, please link to the issue here. -->
|
||||
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
|
||||
- [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes)
|
||||
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
|
||||
|
||||
## Types of changes
|
||||
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
|
||||
|
|
@@ -19,7 +19,7 @@
|
|||
## Checklist:
|
||||
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
|
||||
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
|
||||
- [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**)
|
||||
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
|
||||
- [ ] I have updated the changelog accordingly (**required**).
|
||||
- [ ] My change requires a change to the documentation.
|
||||
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
|
||||
|
|
|
|||
|
|
@@ -12,3 +12,8 @@ pytest:
|
|||
doc-build:
|
||||
script:
|
||||
- make doc
|
||||
|
||||
lint-check:
|
||||
script:
|
||||
- pip install flake8 # TODO: remove when new version on Pypi
|
||||
- make lint
|
||||
|
|
|
|||
|
|
@@ -8,19 +8,9 @@ version: 2
|
|||
sphinx:
|
||||
configuration: docs/conf.py
|
||||
|
||||
# Build documentation with MkDocs
|
||||
#mkdocs:
|
||||
# configuration: mkdocs.yml
|
||||
|
||||
# Optionally build your docs in additional formats such as PDF and ePub
|
||||
formats: all
|
||||
|
||||
# Set requirements using conda env
|
||||
conda:
|
||||
environment: docs/conda_env.yml
|
||||
|
||||
# Optionally set the version of Python and requirements required to build your docs
|
||||
# python:
|
||||
# version: 3.7
|
||||
# install:
|
||||
# - requirements: docs/requirements.txt
|
||||
|
|
|
|||
49
.travis.yml
|
|
@@ -1,49 +0,0 @@
|
|||
language: python
|
||||
python:
|
||||
- "3.6"
|
||||
|
||||
env:
|
||||
global:
|
||||
- DOCKER_IMAGE=stablebaselines/stable-baselines3-cpu:0.6.0a5
|
||||
|
||||
notifications:
|
||||
email: false
|
||||
|
||||
services:
|
||||
- docker
|
||||
|
||||
install:
|
||||
- docker pull ${DOCKER_IMAGE}
|
||||
|
||||
script:
|
||||
- ./scripts/run_tests_travis.sh "${TEST_GLOB}"
|
||||
|
||||
jobs:
|
||||
include:
|
||||
# Big test suite. Run in parallel to decrease wall-clock time, and to avoid OOM error from leaks
|
||||
- stage: Test
|
||||
name: "Unit Tests a-h"
|
||||
env: TEST_GLOB="[a-h]*"
|
||||
|
||||
- name: "Unit Tests i-l"
|
||||
env: TEST_GLOB="[i-l]*"
|
||||
|
||||
- name: "Unit Tests m-sa"
|
||||
env: TEST_GLOB="{[m-r]*,sa*}"
|
||||
|
||||
- name: "Unit Tests sb-z"
|
||||
env: TEST_GLOB="{s[b-z]*,[t-z]*}"
|
||||
|
||||
- name: "Sphinx Documentation"
|
||||
script:
|
||||
- 'docker run -it --rm --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && pushd docs/ && make clean && make html"'
|
||||
|
||||
- name: "Type Checking"
|
||||
script:
|
||||
- 'docker run --rm --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && pytype --version && pytype"'
|
||||
|
||||
- stage: Codacy Trigger
|
||||
if: type != pull_request
|
||||
script:
|
||||
# When all test coverage reports have been uploaded, instruct Codacy to start analysis.
|
||||
- 'docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && java -jar /root/code/codacy-coverage-reporter.jar final"'
|
||||
10
README.md
|
|
@@ -177,18 +177,24 @@ Actions `gym.spaces`:
|
|||
|
||||
|
||||
## Testing the installation
|
||||
All unit tests in stable baselines3 can be run using pytest runner:
|
||||
All unit tests in stable baselines3 can be run using `pytest` runner:
|
||||
```
|
||||
pip install pytest pytest-cov
|
||||
make pytest
|
||||
```
|
||||
|
||||
You can also do a static type check using pytype:
|
||||
You can also do a static type check using `pytype`:
|
||||
```
|
||||
pip install pytype
|
||||
make type
|
||||
```
|
||||
|
||||
Codestyle check with `flake8`:
|
||||
```
|
||||
pip install flake8
|
||||
make lint
|
||||
```
|
||||
|
||||
## Projects Using Stable-Baselines3
|
||||
|
||||
We try to maintain a list of project using stable-baselines3 in the [documentation](https://stable-baselines3.readthedocs.io/en/master/misc/projects.html),
|
||||
|
|
|
|||
|
|
@@ -77,7 +77,7 @@ State-Dependent Exploration
|
|||
|
||||
State-Dependent Exploration (SDE) is a type of exploration that allows to use RL directly on real robots,
|
||||
that was the starting point for the Stable-Baselines3 library.
|
||||
I (@araffin) will publish a paper about a generalized version of SDE (the one implemented in SB3) soon.
|
||||
I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): https://arxiv.org/abs/2005.05719
|
||||
|
||||
Misc
|
||||
====
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@
|
|||
Changelog
|
||||
==========
|
||||
|
||||
Pre-Release 0.6.0a7 (WIP)
|
||||
Pre-Release 0.6.0a8 (WIP)
|
||||
------------------------------
|
||||
|
||||
|
||||
|
|
@@ -41,6 +41,7 @@ Documentation:
|
|||
^^^^^^^^^^^^^^
|
||||
- Added most documentation (adapted from Stable-Baselines)
|
||||
- Added link to CONTRIBUTING.md in the README (@kinalmehta)
|
||||
- Added gSDE project and update docstrings accordingly
|
||||
|
||||
|
||||
Pre-Release 0.5.0 (2020-05-05)
|
||||
|
|
|
|||
|
|
@@ -16,11 +16,12 @@ Please tell us, if you want your project to appear on this page ;)
|
|||
.. | Github repo: https://github.com/araffin/RL-Racing-Robot
|
||||
|
||||
|
||||
.. Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
|
||||
.. -----------------------------------------------------------------------------------
|
||||
..
|
||||
.. An exploration method to train RL agent directly on real robots.
|
||||
..
|
||||
.. | Author: Antonin Raffin, Freek Stulp
|
||||
.. | Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
|
||||
.. | Paper:
|
||||
Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
|
||||
-----------------------------------------------------------------------------------
|
||||
|
||||
An exploration method to train RL agent directly on real robots.
|
||||
It was the starting point of Stable-Baselines3.
|
||||
|
||||
| Author: Antonin Raffin, Freek Stulp
|
||||
| Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
|
||||
| Paper: https://arxiv.org/abs/2005.05719
|
||||
|
|
|
|||
|
|
@@ -1,32 +0,0 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
DOCKER_CMD="docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind"
|
||||
BASH_CMD="cd /root/code/stable-baselines3/"
|
||||
|
||||
if [[ $# -ne 1 ]]; then
|
||||
echo "usage: $0 <test glob>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ${DOCKER_IMAGE} = "" ]]; then
|
||||
echo "Need DOCKER_IMAGE environment variable to be set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
TEST_GLOB=$1
|
||||
|
||||
set -e # exit immediately on any error
|
||||
|
||||
# For pull requests from fork, Codacy token is not available, leading to build failure
|
||||
if [[ ${CODACY_PROJECT_TOKEN} = "" ]]; then
|
||||
echo "WARNING: CODACY_PROJECT_TOKEN not set. Skipping Codacy upload."
|
||||
echo "(This is normal when building in a fork and can be ignored.)"
|
||||
${DOCKER_CMD} ${DOCKER_IMAGE} \
|
||||
bash -c "${BASH_CMD} && \
|
||||
pytest --cov-config .coveragerc --cov-report term --cov=. -v tests/test_${TEST_GLOB}"
|
||||
else
|
||||
${DOCKER_CMD} --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} \
|
||||
bash -c "${BASH_CMD} && \
|
||||
pytest --cov-config .coveragerc --cov-report term --cov-report xml --cov=. -v tests/test_${TEST_GLOB} && \
|
||||
java -jar /root/code/codacy-coverage-reporter.jar report -l python -r coverage.xml --partial"
|
||||
fi
|
||||
|
|
@@ -34,9 +34,9 @@ class A2C(PPO):
|
|||
:param rms_prop_eps: (float) RMSProp epsilon. It stabilizes square root computation in denominator
|
||||
of RMSProp update
|
||||
:param use_rms_prop: (bool) Whether to use RMSprop (default) or Adam as optimizer
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param normalize_advantage: (bool) Whether to normalize or not the advantage
|
||||
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
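The `sde_sample_freq` semantics described in this docstring can be sketched as follows. This is a minimal illustration, not the library's rollout loop: the helper name `noise_resample_steps` and its `rollout_length` argument are hypothetical, and the sketch assumes the noise matrix is sampled once at the start of every rollout and then again every `sde_sample_freq` steps when that value is positive.

```python
from typing import List


def noise_resample_steps(rollout_length: int, sde_sample_freq: int = -1) -> List[int]:
    """Return the step indices at which a new noise matrix would be sampled.

    Hypothetical helper for illustration only (not part of stable-baselines3).
    """
    # Assumption: noise is always sampled at the beginning of the rollout.
    steps = [0]
    if sde_sample_freq > 0:
        # Resample every `sde_sample_freq` steps within the rollout.
        steps += [s for s in range(1, rollout_length) if s % sde_sample_freq == 0]
    return steps


default_steps = noise_resample_steps(8)       # default -1: only at the start
periodic_steps = noise_resample_steps(8, 4)   # start, then every 4 steps
```

With the default of `-1`, the same noise matrix is kept for the whole rollout; a positive value trades smoother exploration for more frequent resampling.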
|
||||
|
|
|
|||
|
|
@@ -46,9 +46,9 @@ class BaseRLModel(ABC):
|
|||
:param monitor_wrapper: (bool) When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: (Optional[int]) Seed for the pseudo random generators
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
"""
|
||||
|
||||
|
|
@@ -96,7 +96,7 @@ class BaseRLModel(ABC):
|
|||
# When using VecNormalize:
|
||||
self._last_original_obs = None # type: Optional[np.ndarray]
|
||||
self._episode_num = 0
|
||||
# Used for SDE only
|
||||
# Used for gSDE only
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
# Track the training progress (from 1 to 0)
|
||||
|
|
@@ -681,11 +681,11 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param sde_support: (bool) Whether the model support SDE or not
|
||||
:param sde_support: (bool) Whether the model support gSDE or not
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@@ -721,7 +721,7 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
if sde_support:
|
||||
self.policy_kwargs['use_sde'] = self.use_sde
|
||||
self.policy_kwargs['device'] = self.device
|
||||
# For SDE only
|
||||
# For gSDE only
|
||||
self.use_sde_at_warmup = use_sde_at_warmup
|
||||
|
||||
def _setup_model(self):
|
||||
|
|
|
|||
|
|
@@ -276,7 +276,7 @@ class EvalCallback(EventCallback):
|
|||
|
||||
def _init_callback(self):
|
||||
# Does not work in some corner cases, where the wrapper is not the same
|
||||
if not type(self.training_env) is type(self.eval_env):
|
||||
if not isinstance(self.training_env, type(self.eval_env)):
|
||||
warnings.warn("Training and eval env are not of the same type"
|
||||
f"{self.training_env} != {self.eval_env}")
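Review note: the two checks in this hunk are not equivalent. `type(a) is type(b)` requires the exact same class, while `isinstance(a, type(b))` also accepts subclasses and is therefore asymmetric. A minimal sketch of the difference, using two hypothetical stand-in classes (`BaseWrapper` and `MonitoredWrapper` are not part of the library):

```python
class BaseWrapper:
    """Hypothetical stand-in for an environment wrapper."""


class MonitoredWrapper(BaseWrapper):
    """Hypothetical subclass of the wrapper above."""


training_env = MonitoredWrapper()
eval_env = BaseWrapper()

# Old check: passes only when both objects have exactly the same class.
exact_match = type(training_env) is type(eval_env)

# New check: also passes when training_env is an instance of a subclass.
subclass_match = isinstance(training_env, type(eval_env))

# Swapping the arguments gives a different answer, so the check is asymmetric.
reverse_match = isinstance(eval_env, type(training_env))
```

With the new check, the warning is skipped whenever the training environment's class derives from the evaluation environment's class, but not the other way round.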
|
||||
|
||||
|
|
|
|||
|
|
@@ -294,7 +294,9 @@ class CategoricalDistribution(Distribution):
|
|||
|
||||
class StateDependentNoiseDistribution(Distribution):
|
||||
"""
|
||||
Distribution class for using State Dependent Exploration (SDE).
|
||||
Distribution class for using generalized State Dependent Exploration (gSDE).
|
||||
Paper: https://arxiv.org/abs/2005.05719
|
||||
|
||||
It is used to create the noise exploration matrix and
|
||||
compute the log probability of an action with that noise.
|
||||
|
||||
|
|
@@ -306,7 +308,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries.
|
||||
:param learn_features: (bool) Whether to learn features for SDE or not.
|
||||
:param learn_features: (bool) Whether to learn features for gSDE or not.
|
||||
This will enable gradients to be backpropagated through the features
|
||||
``latent_sde`` in the code.
|
||||
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
|
||||
|
|
@@ -346,7 +348,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
:return: (th.Tensor)
|
||||
"""
|
||||
if self.use_expln:
|
||||
# From SDE paper, it allows to keep variance
|
||||
# From gSDE paper, it allows to keep variance
|
||||
# above zero and prevent it from growing too fast
|
||||
below_threshold = th.exp(log_std) * (log_std <= 0)
|
||||
# Avoid NaN: zeros values that are below zero
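For readers unfamiliar with `expln`, here is a scalar sketch of the idea, assuming the piecewise definition `exp(x)` for `x <= 0` and `log(x + 1) + 1` for `x > 0`. It uses plain `math` rather than the tensor masking shown in the hunk, so it illustrates the function's shape and is not the code in this file.

```python
import math


def expln(log_std: float) -> float:
    """Scalar sketch of expln: always positive, grows slowly for large inputs."""
    if log_std <= 0:
        # Below the threshold, behave exactly like exp().
        return math.exp(log_std)
    # Above the threshold, grow logarithmically instead of exponentially.
    return math.log1p(log_std) + 1.0


std_at_zero = expln(0.0)      # both branches meet at 1.0
std_large = expln(10.0)       # far smaller than exp(10.0)
```

Both branches agree at `0`, so the function is continuous. The tensor version in the hunk evaluates both branches for every element and masks the result, which is why it has to guard against NaN from the logarithm of negative inputs.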
|
||||
|
|
@@ -387,7 +389,7 @@ class StateDependentNoiseDistribution(Distribution):
|
|||
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param latent_sde_dim: (Optional[int]) Dimension of the last layer of the feature extractor
|
||||
for SDE. By default, it is shared with the policy network.
|
||||
for gSDE. By default, it is shared with the policy network.
|
||||
:return: (nn.Linear, nn.Parameter)
|
||||
"""
|
||||
# Network for the deterministic action, it represents the mean of the distribution
|
||||
|
|
|
|||
|
|
@@ -416,7 +416,7 @@ def create_sde_features_extractor(features_dim: int,
|
|||
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
|
||||
"""
|
||||
Create the neural network that will be used to extract features
|
||||
for the SDE exploration function.
|
||||
for the gSDE exploration function.
|
||||
|
||||
:param features_dim: (int)
|
||||
:param sde_net_arch: ([int])
|
||||
|
|
|
|||
|
|
@@ -67,14 +67,12 @@ class VecVideoRecorder(VecEnvWrapper):
|
|||
def start_video_recorder(self):
|
||||
self.close_video_recorder()
|
||||
|
||||
video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
|
||||
self.step_id + self.video_length)
|
||||
video_name = f'{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}'
|
||||
base_path = os.path.join(self.video_folder, video_name)
|
||||
self.video_recorder = video_recorder.VideoRecorder(
|
||||
env=self.env,
|
||||
base_path=base_path,
|
||||
metadata={'step_id': self.step_id}
|
||||
)
|
||||
self.video_recorder = video_recorder.VideoRecorder(env=self.env,
|
||||
base_path=base_path,
|
||||
metadata={'step_id': self.step_id}
|
||||
)
|
||||
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames = 1
|
||||
|
|
|
|||
|
|
@@ -28,15 +28,15 @@ class PPOPolicy(BasePolicy):
|
|||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries when using SDE.
|
||||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
|
|
@@ -100,7 +100,7 @@ class PPOPolicy(BasePolicy):
|
|||
self.normalize_images = normalize_images
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
# Keyword arguments for SDE distribution
|
||||
# Keyword arguments for gSDE distribution
|
||||
if use_sde:
|
||||
dist_kwargs = {
|
||||
'full_std': full_std,
|
||||
|
|
@@ -147,7 +147,7 @@ class PPOPolicy(BasePolicy):
|
|||
:param n_envs: (int)
|
||||
"""
|
||||
assert isinstance(self.action_dist,
|
||||
StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
|
||||
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
|
|
@@ -162,7 +162,7 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
latent_dim_pi = self.mlp_extractor.latent_dim_pi
|
||||
|
||||
# Separate feature extractor for SDE
|
||||
# Separate feature extractor for gSDE
|
||||
if self.sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim,
|
||||
self.sde_net_arch,
|
||||
|
|
@@ -221,7 +221,7 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
:param obs: (th.Tensor) Observation
|
||||
:return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes
|
||||
for the actor, the value function and for SDE function
|
||||
for the actor, the value function and for gSDE function
|
||||
"""
|
||||
# Preprocess the observation if needed
|
||||
features = self.extract_features(obs)
|
||||
|
|
@@ -238,7 +238,7 @@ class PPOPolicy(BasePolicy):
|
|||
Retrieve action distribution given the latent codes.
|
||||
|
||||
:param latent_pi: (th.Tensor) Latent code for the actor
|
||||
:param latent_sde: (Optional[th.Tensor]) Latent code for the SDE exploration function
|
||||
:param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function
|
||||
:return: (Distribution) Action distribution
|
||||
"""
|
||||
mean_actions = self.action_net(latent_pi)
|
||||
|
|
@@ -302,15 +302,15 @@ class CnnPolicy(PPOPolicy):
|
|||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE
|
||||
for the std instead of only (n_features,) when using gSDE
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param squash_output: (bool) Whether to squash the output using a tanh function,
|
||||
this allows to ensure boundaries when using SDE.
|
||||
this allows to ensure boundaries when using gSDE.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
|
|
|
|||
|
|
@@ -55,9 +55,9 @@ class PPO(BaseRLModel):
|
|||
:param ent_coef: (float) Entropy coefficient for the loss calculation
|
||||
:param vf_coef: (float) Value function coefficient for the loss calculation
|
||||
:param max_grad_norm: (float) The maximum value for the gradient clipping
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param target_kl: (float) Limit the KL divergence between updates,
|
||||
because the clipping is not enough to prevent large update
|
||||
|
|
|
|||
|
|
@@ -29,14 +29,14 @@ class Actor(BasePolicy):
|
|||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
|
||||
for the std instead of only (n_features,) when using SDE.
|
||||
for the std instead of only (n_features,) when using gSDE.
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
|
||||
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
|
|
@@ -82,7 +82,7 @@ class Actor(BasePolicy):
|
|||
|
||||
if self.use_sde:
|
||||
latent_sde_dim = last_layer_dim
|
||||
# Separate feature extractor for SDE
|
||||
# Separate feature extractor for gSDE
|
||||
if sde_net_arch is not None:
|
||||
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(features_dim, sde_net_arch,
|
||||
activation_fn)
|
||||
|
|
@@ -121,7 +121,7 @@ class Actor(BasePolicy):
|
|||
def get_std(self) -> th.Tensor:
|
||||
"""
|
||||
Retrieve the standard deviation of the action distribution.
|
||||
Only useful when using SDE.
|
||||
Only useful when using gSDE.
|
||||
It corresponds to ``th.exp(log_std)`` in the normal case,
|
||||
but is slightly different when using ``expln`` function
|
||||
(cf StateDependentNoiseDistribution doc).
|
||||
|
|
@@ -129,17 +129,17 @@ class Actor(BasePolicy):
|
|||
:return: (th.Tensor)
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
|
||||
'get_std() is only available when using SDE'
|
||||
'get_std() is only available when using gSDE'
|
||||
return self.action_dist.get_std(self.log_std)
|
||||
|
||||
def reset_noise(self, batch_size: int = 1) -> None:
|
||||
"""
|
||||
Sample new weights for the exploration matrix, when using SDE.
|
||||
Sample new weights for the exploration matrix, when using gSDE.
|
||||
|
||||
:param batch_size: (int)
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
|
||||
'reset_noise() is only available when using SDE'
|
||||
'reset_noise() is only available when using gSDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
|
||||
|
||||
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
|
||||
|
|
@@ -241,12 +241,12 @@ class SACPolicy(BasePolicy):
|
|||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
|
||||
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
|
||||
to pass to the feature extractor.
|
||||
|
|
@@ -383,12 +383,12 @@ class CnnPolicy(SACPolicy):
|
|||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
:param log_std_init: (float) Initial value for the log standard deviation
|
||||
:param sde_net_arch: ([int]) Network architecture for extracting features
|
||||
when using SDE. If None, the latent features from the policy will be used.
|
||||
when using gSDE. If None, the latent features from the policy will be used.
|
||||
Pass an empty list to use the states as features.
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
|
||||
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
|
||||
a positive standard deviation (cf paper). It allows to keep variance
|
||||
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
|
||||
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
|
||||
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
|
||||
:param normalize_images: (bool) Whether to normalize images or not,
|
||||
dividing by 255.0 (True by default)
|
||||
|
|
|
|||
|
|
@@ -46,11 +46,11 @@ class SAC(OffPolicyRLModel):
|
|||
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
|
||||
:param target_update_interval: (int) update the target network every ``target_network_update_freq`` steps.
|
||||
:param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
|
||||
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
|
||||
during the warm up phase (before learning starts)
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
|
|
|
|||
|
|
@@ -1 +1 @@
|
|||
0.6.0a7
|
||||
0.6.0a8
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,7 @@ import gym
|
|||
|
||||
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||||
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
|
||||
EveryNTimesteps, StopTrainingOnRewardThreshold)
|
||||
EveryNTimesteps, StopTrainingOnRewardThreshold)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
|
||||
|
|
|
|||
|
|
@@ -3,8 +3,8 @@ import torch as th
|
|||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector,
|
||||
StateDependentNoiseDistribution,
|
||||
CategoricalDistribution, SquashedDiagGaussianDistribution)
|
||||
StateDependentNoiseDistribution,
|
||||
CategoricalDistribution, SquashedDiagGaussianDistribution)
|
||||
from stable_baselines3.common.utils import set_random_seed
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -6,7 +6,7 @@ import numpy as np
|
|||
from stable_baselines3.common.env_checker import check_env
|
||||
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
|
||||
from stable_baselines3.common.identity_env import (IdentityEnv, IdentityEnvBox, FakeImageEnv,
|
||||
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete,)
|
||||
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
|
||||
|
||||
ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary,
|
||||
IdentityEnvMultiDiscrete, FakeImageEnv]
|
||||
|
|
|
|||
|
|
@@ -5,8 +5,8 @@ import pytest
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
|
||||
info, debug, set_level, configure, logkv, logkvs,
|
||||
dumpkvs, logkv_mean, warn, error, reset)
|
||||
info, debug, set_level, configure, logkv, logkvs,
|
||||
dumpkvs, logkv_mean, warn, error, reset)
|
||||
|
||||
KEY_VALUES = {
|
||||
"test": 1,
|
||||
|
|
|
|||
|
|
@@ -187,7 +187,7 @@ def test_save_load_policy(model_class, policy_str):
|
|||
|
||||
# create model
|
||||
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
|
||||
verbose=1, **kwargs)
|
||||
verbose=1, **kwargs)
|
||||
model.learn(total_timesteps=500, eval_freq=250)
|
||||
|
||||
env.reset()
|
||||
|
|
|
|||