Documentation update and style fixes (#21)

* Update doc: add gSDE

* Fix codestyle

* Remove travis script

* Add lint check to gitlab
This commit is contained in:
Antonin RAFFIN 2020-05-15 13:54:06 +02:00 committed by GitHub
parent 54f6f5b6fb
commit 15ff6d47ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 88 additions and 166 deletions

View file

@ -7,7 +7,7 @@
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
<!--- You can use the syntax `closes #100` if this solves the issue #100 -->
- [ ] I have raised an issue to propose this change ([required](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) for new features and bug fixes)
- [ ] I have raised an issue to propose this change ([required](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) for new features and bug fixes)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->
@ -19,7 +19,7 @@
## Checklist:
<!--- Go over all the following points, and put an `x` in all the boxes that apply. -->
<!--- If you're unsure about any of these, don't hesitate to ask. We're here to help! -->
- [ ] I've read the [CONTRIBUTION](https://github.com/hill-a/stable-baselines/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I've read the [CONTRIBUTION](https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md) guide (**required**)
- [ ] I have updated the changelog accordingly (**required**).
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).

View file

@ -12,3 +12,8 @@ pytest:
doc-build:
script:
- make doc
lint-check:
script:
- pip install flake8 # TODO: remove when new version on Pypi
- make lint

View file

@ -8,19 +8,9 @@ version: 2
sphinx:
configuration: docs/conf.py
# Build documentation with MkDocs
#mkdocs:
# configuration: mkdocs.yml
# Optionally build your docs in additional formats such as PDF and ePub
formats: all
# Set requirements using conda env
conda:
environment: docs/conda_env.yml
# Optionally set the version of Python and requirements required to build your docs
# python:
# version: 3.7
# install:
# - requirements: docs/requirements.txt

View file

@ -1,49 +0,0 @@
language: python
python:
- "3.6"
env:
global:
- DOCKER_IMAGE=stablebaselines/stable-baselines3-cpu:0.6.0a5
notifications:
email: false
services:
- docker
install:
- docker pull ${DOCKER_IMAGE}
script:
- ./scripts/run_tests_travis.sh "${TEST_GLOB}"
jobs:
include:
# Big test suite. Run in parallel to decrease wall-clock time, and to avoid OOM error from leaks
- stage: Test
name: "Unit Tests a-h"
env: TEST_GLOB="[a-h]*"
- name: "Unit Tests i-l"
env: TEST_GLOB="[i-l]*"
- name: "Unit Tests m-sa"
env: TEST_GLOB="{[m-r]*,sa*}"
- name: "Unit Tests sb-z"
env: TEST_GLOB="{s[b-z]*,[t-z]*}"
- name: "Sphinx Documentation"
script:
- 'docker run -it --rm --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && pushd docs/ && make clean && make html"'
- name: "Type Checking"
script:
- 'docker run --rm --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && pytype --version && pytype"'
- stage: Codacy Trigger
if: type != pull_request
script:
# When all test coverage reports have been uploaded, instruct Codacy to start analysis.
- 'docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} bash -c "cd /root/code/stable-baselines3/ && java -jar /root/code/codacy-coverage-reporter.jar final"'

View file

@ -177,18 +177,24 @@ Actions `gym.spaces`:
## Testing the installation
All unit tests in stable baselines3 can be run using pytest runner:
All unit tests in stable baselines3 can be run using `pytest` runner:
```
pip install pytest pytest-cov
make pytest
```
You can also do a static type check using pytype:
You can also do a static type check using `pytype`:
```
pip install pytype
make type
```
Codestyle check with `flake8`:
```
pip install flake8
make lint
```
## Projects Using Stable-Baselines3
We try to maintain a list of project using stable-baselines3 in the [documentation](https://stable-baselines3.readthedocs.io/en/master/misc/projects.html),

View file

@ -77,7 +77,7 @@ State-Dependent Exploration
State-Dependent Exploration (SDE) is a type of exploration that allows to use RL directly on real robots,
that was the starting point for the Stable-Baselines3 library.
I (@araffin) will publish a paper about a generalized version of SDE (the one implemented in SB3) soon.
I (@araffin) published a paper about a generalized version of SDE (the one implemented in SB3): https://arxiv.org/abs/2005.05719
Misc
====

View file

@ -3,7 +3,7 @@
Changelog
==========
Pre-Release 0.6.0a7 (WIP)
Pre-Release 0.6.0a8 (WIP)
------------------------------
@ -41,6 +41,7 @@ Documentation:
^^^^^^^^^^^^^^
- Added most documentation (adapted from Stable-Baselines)
- Added link to CONTRIBUTING.md in the README (@kinalmehta)
- Added gSDE project and update docstrings accordingly
Pre-Release 0.5.0 (2020-05-05)

View file

@ -16,11 +16,12 @@ Please tell us, if you want your project to appear on this page ;)
.. | Github repo: https://github.com/araffin/RL-Racing-Robot
.. Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
.. -----------------------------------------------------------------------------------
..
.. An exploration method to train RL agent directly on real robots.
..
.. | Author: Antonin Raffin, Freek Stulp
.. | Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
.. | Paper:
Generalized State Dependent Exploration for Deep Reinforcement Learning in Robotics
-----------------------------------------------------------------------------------
An exploration method to train RL agent directly on real robots.
It was the starting point of Stable-Baselines3.
| Author: Antonin Raffin, Freek Stulp
| Github: https://github.com/DLR-RM/stable-baselines3/tree/sde
| Paper: https://arxiv.org/abs/2005.05719

View file

@ -1,32 +0,0 @@
#!/usr/bin/env bash
DOCKER_CMD="docker run -it --rm --network host --ipc=host --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind"
BASH_CMD="cd /root/code/stable-baselines3/"
if [[ $# -ne 1 ]]; then
echo "usage: $0 <test glob>"
exit 1
fi
if [[ ${DOCKER_IMAGE} = "" ]]; then
echo "Need DOCKER_IMAGE environment variable to be set."
exit 1
fi
TEST_GLOB=$1
set -e # exit immediately on any error
# For pull requests from fork, Codacy token is not available, leading to build failure
if [[ ${CODACY_PROJECT_TOKEN} = "" ]]; then
echo "WARNING: CODACY_PROJECT_TOKEN not set. Skipping Codacy upload."
echo "(This is normal when building in a fork and can be ignored.)"
${DOCKER_CMD} ${DOCKER_IMAGE} \
bash -c "${BASH_CMD} && \
pytest --cov-config .coveragerc --cov-report term --cov=. -v tests/test_${TEST_GLOB}"
else
${DOCKER_CMD} --env CODACY_PROJECT_TOKEN=${CODACY_PROJECT_TOKEN} ${DOCKER_IMAGE} \
bash -c "${BASH_CMD} && \
pytest --cov-config .coveragerc --cov-report term --cov-report xml --cov=. -v tests/test_${TEST_GLOB} && \
java -jar /root/code/codacy-coverage-reporter.jar report -l python -r coverage.xml --partial"
fi

View file

@ -34,9 +34,9 @@ class A2C(PPO):
:param rms_prop_eps: (float) RMSProp epsilon. It stabilizes square root computation in denominator
of RMSProp update
:param use_rms_prop: (bool) Whether to use RMSprop (default) or Adam as optimizer
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param normalize_advantage: (bool) Whether to normalize or not the advantage
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)

View file

@ -46,9 +46,9 @@ class BaseRLModel(ABC):
:param monitor_wrapper: (bool) When creating an environment, whether to wrap it
or not in a Monitor wrapper.
:param seed: (Optional[int]) Seed for the pseudo random generators
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
"""
@ -96,7 +96,7 @@ class BaseRLModel(ABC):
# When using VecNormalize:
self._last_original_obs = None # type: Optional[np.ndarray]
self._episode_num = 0
# Used for SDE only
# Used for gSDE only
self.use_sde = use_sde
self.sde_sample_freq = sde_sample_freq
# Track the training progress (from 1 to 0)
@ -681,11 +681,11 @@ class OffPolicyRLModel(BaseRLModel):
:param seed: Seed for the pseudo random generators
:param use_sde: Whether to use State Dependent Exploration (SDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param sde_support: (bool) Whether the model support SDE or not
:param sde_support: (bool) Whether the model support gSDE or not
"""
def __init__(self,
@ -721,7 +721,7 @@ class OffPolicyRLModel(BaseRLModel):
if sde_support:
self.policy_kwargs['use_sde'] = self.use_sde
self.policy_kwargs['device'] = self.device
# For SDE only
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup
def _setup_model(self):

View file

@ -276,7 +276,7 @@ class EvalCallback(EventCallback):
def _init_callback(self):
# Does not work in some corner cases, where the wrapper is not the same
if not type(self.training_env) is type(self.eval_env):
if not isinstance(self.training_env, type(self.eval_env)):
warnings.warn("Training and eval env are not of the same type"
f"{self.training_env} != {self.eval_env}")

View file

@ -294,7 +294,9 @@ class CategoricalDistribution(Distribution):
class StateDependentNoiseDistribution(Distribution):
"""
Distribution class for using State Dependent Exploration (SDE).
Distribution class for using generalized State Dependent Exploration (gSDE).
Paper: https://arxiv.org/abs/2005.05719
It is used to create the noise exploration matrix and
compute the log probability of an action with that noise.
@ -306,7 +308,7 @@ class StateDependentNoiseDistribution(Distribution):
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries.
:param learn_features: (bool) Whether to learn features for SDE or not.
:param learn_features: (bool) Whether to learn features for gSDE or not.
This will enable gradients to be backpropagated through the features
``latent_sde`` in the code.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
@ -346,7 +348,7 @@ class StateDependentNoiseDistribution(Distribution):
:return: (th.Tensor)
"""
if self.use_expln:
# From SDE paper, it allows to keep variance
# From gSDE paper, it allows to keep variance
# above zero and prevent it from growing too fast
below_threshold = th.exp(log_std) * (log_std <= 0)
# Avoid NaN: zeros values that are below zero
@ -387,7 +389,7 @@ class StateDependentNoiseDistribution(Distribution):
:param latent_dim: (int) Dimension of the last layer of the policy (before the action layer)
:param log_std_init: (float) Initial value for the log standard deviation
:param latent_sde_dim: (Optional[int]) Dimension of the last layer of the feature extractor
for SDE. By default, it is shared with the policy network.
for gSDE. By default, it is shared with the policy network.
:return: (nn.Linear, nn.Parameter)
"""
# Network for the deterministic action, it represents the mean of the distribution

View file

@ -416,7 +416,7 @@ def create_sde_features_extractor(features_dim: int,
activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
"""
Create the neural network that will be used to extract features
for the SDE exploration function.
for the gSDE exploration function.
:param features_dim: (int)
:param sde_net_arch: ([int])

View file

@ -67,14 +67,12 @@ class VecVideoRecorder(VecEnvWrapper):
def start_video_recorder(self):
self.close_video_recorder()
video_name = '{}-step-{}-to-step-{}'.format(self.name_prefix, self.step_id,
self.step_id + self.video_length)
video_name = f'{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}'
base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env,
base_path=base_path,
metadata={'step_id': self.step_id}
)
self.video_recorder = video_recorder.VideoRecorder(env=self.env,
base_path=base_path,
metadata={'step_id': self.step_id}
)
self.video_recorder.capture_frame()
self.recorded_frames = 1

View file

@ -28,15 +28,15 @@ class PPOPolicy(BasePolicy):
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries when using SDE.
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
@ -100,7 +100,7 @@ class PPOPolicy(BasePolicy):
self.normalize_images = normalize_images
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for SDE distribution
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
'full_std': full_std,
@ -147,7 +147,7 @@ class PPOPolicy(BasePolicy):
:param n_envs: (int)
"""
assert isinstance(self.action_dist,
StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, lr_schedule: Callable) -> None:
@ -162,7 +162,7 @@ class PPOPolicy(BasePolicy):
latent_dim_pi = self.mlp_extractor.latent_dim_pi
# Separate feature extractor for SDE
# Separate feature extractor for gSDE
if self.sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim,
self.sde_net_arch,
@ -221,7 +221,7 @@ class PPOPolicy(BasePolicy):
:param obs: (th.Tensor) Observation
:return: (Tuple[th.Tensor, th.Tensor, th.Tensor]) Latent codes
for the actor, the value function and for SDE function
for the actor, the value function and for gSDE function
"""
# Preprocess the observation if needed
features = self.extract_features(obs)
@ -238,7 +238,7 @@ class PPOPolicy(BasePolicy):
Retrieve action distribution given the latent codes.
:param latent_pi: (th.Tensor) Latent code for the actor
:param latent_sde: (Optional[th.Tensor]) Latent code for the SDE exploration function
:param latent_sde: (Optional[th.Tensor]) Latent code for the gSDE exploration function
:return: (Distribution) Action distribution
"""
mean_actions = self.action_net(latent_pi)
@ -302,15 +302,15 @@ class CnnPolicy(PPOPolicy):
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE
for the std instead of only (n_features,) when using gSDE
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param squash_output: (bool) Whether to squash the output using a tanh function,
this allows to ensure boundaries when using SDE.
this allows to ensure boundaries when using gSDE.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.

View file

@ -55,9 +55,9 @@ class PPO(BaseRLModel):
:param ent_coef: (float) Entropy coefficient for the loss calculation
:param vf_coef: (float) Value function coefficient for the loss calculation
:param max_grad_norm: (float) The maximum value for the gradient clipping
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param target_kl: (float) Limit the KL divergence between updates,
because the clipping is not enough to prevent large update

View file

@ -29,14 +29,14 @@ class Actor(BasePolicy):
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,) when using SDE.
for the std instead of only (n_features,) when using gSDE.
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
@ -82,7 +82,7 @@ class Actor(BasePolicy):
if self.use_sde:
latent_sde_dim = last_layer_dim
# Separate feature extractor for SDE
# Separate feature extractor for gSDE
if sde_net_arch is not None:
self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(features_dim, sde_net_arch,
activation_fn)
@ -121,7 +121,7 @@ class Actor(BasePolicy):
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
Only useful when using SDE.
Only useful when using gSDE.
It corresponds to ``th.exp(log_std)`` in the normal case,
but is slightly different when using ``expln`` function
(cf StateDependentNoiseDistribution doc).
@ -129,17 +129,17 @@ class Actor(BasePolicy):
:return: (th.Tensor)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'get_std() is only available when using SDE'
'get_std() is only available when using gSDE'
return self.action_dist.get_std(self.log_std)
def reset_noise(self, batch_size: int = 1) -> None:
"""
Sample new weights for the exploration matrix, when using SDE.
Sample new weights for the exploration matrix, when using gSDE.
:param batch_size: (int)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), \
'reset_noise() is only available when using SDE'
'reset_noise() is only available when using gSDE'
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
@ -241,12 +241,12 @@ class SACPolicy(BasePolicy):
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param features_extractor_kwargs: (Optional[Dict[str, Any]]) Keyword arguments
to pass to the feature extractor.
@ -383,12 +383,12 @@ class CnnPolicy(SACPolicy):
:param use_sde: (bool) Whether to use State Dependent Exploration or not
:param log_std_init: (float) Initial value for the log standard deviation
:param sde_net_arch: ([int]) Network architecture for extracting features
when using SDE. If None, the latent features from the policy will be used.
when using gSDE. If None, the latent features from the policy will be used.
Pass an empty list to use the states as features.
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using SDE to ensure
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure
a positive standard deviation (cf paper). It allows to keep variance
above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.
:param clip_mean: (float) Clip the mean output when using SDE to avoid numerical instability.
:param clip_mean: (float) Clip the mean output when using gSDE to avoid numerical instability.
:param features_extractor_class: (Type[BaseFeaturesExtractor]) Features extractor to use.
:param normalize_images: (bool) Whether to normalize images or not,
dividing by 255.0 (True by default)

View file

@ -46,11 +46,11 @@ class SAC(OffPolicyRLModel):
Set it to 'auto' to learn it automatically (and 'auto_0.1' for using 0.1 as initial value)
:param target_update_interval: (int) update the target network every ``target_network_update_freq`` steps.
:param target_entropy: (str or float) target entropy when learning ``ent_coef`` (``ent_coef = 'auto'``)
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
:param use_sde: (bool) Whether to use generalized State Dependent Exploration (gSDE)
instead of action noise exploration (default: False)
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param use_sde_at_warmup: (bool) Whether to use SDE instead of uniform sampling
:param use_sde_at_warmup: (bool) Whether to use gSDE instead of uniform sampling
during the warm up phase (before learning starts)
:param create_eval_env: (bool) Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)

View file

@ -1 +1 @@
0.6.0a7
0.6.0a8

View file

@ -6,7 +6,7 @@ import gym
from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
EveryNTimesteps, StopTrainingOnRewardThreshold)
EveryNTimesteps, StopTrainingOnRewardThreshold)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])

View file

@ -3,8 +3,8 @@ import torch as th
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector,
StateDependentNoiseDistribution,
CategoricalDistribution, SquashedDiagGaussianDistribution)
StateDependentNoiseDistribution,
CategoricalDistribution, SquashedDiagGaussianDistribution)
from stable_baselines3.common.utils import set_random_seed

View file

@ -6,7 +6,7 @@ import numpy as np
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
from stable_baselines3.common.identity_env import (IdentityEnv, IdentityEnvBox, FakeImageEnv,
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete,)
IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary,
IdentityEnvMultiDiscrete, FakeImageEnv]

View file

@ -5,8 +5,8 @@ import pytest
import numpy as np
from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
info, debug, set_level, configure, logkv, logkvs,
dumpkvs, logkv_mean, warn, error, reset)
info, debug, set_level, configure, logkv, logkvs,
dumpkvs, logkv_mean, warn, error, reset)
KEY_VALUES = {
"test": 1,

View file

@ -187,7 +187,7 @@ def test_save_load_policy(model_class, policy_str):
# create model
model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]),
verbose=1, **kwargs)
verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()