Mirror of https://github.com/saymrwulf/stable-baselines3.git, synced 2026-05-16 21:10:08 +00:00
Merge pull request #41 from Antonin-Raffin/docs/build

Build documentation

Commit 8152b34aaa. 19 changed files with 379 additions and 38 deletions.
README.md (16 changes)
@@ -20,3 +20,19 @@ PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines)

+## Roadmap
+
+- cf github Roadmap
+
+## Citing the Project
+
+To cite this repository in publications:
+
+```
+@misc{torchy-baselines,
+  author = {Raffin, Antonin and Dormann, Noah and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
+  title = {Torchy Baselines},
+  year = {2019},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/araffin/torchy-baselines}},
+}
+```
docs/conf.py (23 changes)
@@ -16,6 +16,14 @@ import os
 import sys
 from unittest.mock import MagicMock

+# We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support
+# PyEnchant.
+try:
+    import sphinxcontrib.spelling
+    enable_spell_check = True
+except ImportError:
+    enable_spell_check = False
+
 # source code directory, relative to this file, for sphinx-autobuild
 sys.path.insert(0, os.path.abspath('..'))
@@ -31,16 +39,8 @@ class Mock(MagicMock):
 # Mock modules that requires C modules
 # Note: because of that we cannot test examples using CI
 # 'torch', 'torch.nn', 'torch.nn.functional',
-MOCK_MODULES = ['joblib', 'scipy', 'scipy.signal',
-                'pandas', 'mpi4py', 'mujoco-py', 'cv2',
-                'tensorflow', 'torch', 'torch.nn', 'torch.nn.functional',
-                'torch.distributions',
-                'tensorflow.contrib', 'tensorflow.contrib.layers',
-                'tensorflow.python', 'tensorflow.python.client', 'tensorflow.python.ops',
-                'tqdm', 'cloudpickle', 'matplotlib', 'matplotlib.pyplot',
-                'seaborn', 'gym', 'gym.spaces', 'gym.core',
-                'tensorflow.core', 'tensorflow.core.util', 'tensorflow.python.util',
-                'gym.wrappers', 'gym.wrappers.monitoring', 'zmq']
+# Do not mock modules for now, we will need to do that for Read the Docs later
+MOCK_MODULES = []
 sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
@@ -76,6 +76,9 @@ extensions = [
     'sphinx.ext.viewcode',
 ]

+if enable_spell_check:
+    extensions.append('sphinxcontrib.spelling')
+
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
@@ -16,8 +16,7 @@ Here is a quick example of how to train and run SAC on a Pendulum environment:
     from torchy_baselines.common.vec_env import DummyVecEnv
     from torchy_baselines import SAC

-    # The algorithms require a vectorized environment to run
-    env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
+    env = gym.make('Pendulum-v0')

     model = SAC(MlpPolicy, env, verbose=1)
     model.learn(total_timesteps=10000)
@@ -27,7 +27,7 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️

 When using vectorized environments, the environments are automatically reset at the end of each episode.
 Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
-You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
+You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the `VecEnv`.

 .. warning::
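To illustrate the ``terminal_observation`` behaviour this hunk documents, here is a minimal sketch, assuming the ``DummyVecEnv`` API used elsewhere in these docs (one info dict per sub-environment; variable names are illustrative):

```python
import gym
import numpy as np

from torchy_baselines.common.vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make('Pendulum-v0')])
obs = env.reset()

for _ in range(500):
    # Random actions, just to reach the end of an episode
    actions = np.stack([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rewards, dones, infos = env.step(actions)
    for i in range(env.num_envs):
        if dones[i]:
            # obs[i] is already the first observation of the NEXT episode;
            # the true final observation is carried in the info dict
            final_obs = infos[i]['terminal_observation']
```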
@@ -28,6 +28,8 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and do
   :caption: RL Algorithms

   modules/base
   modules/a2c
+  modules/cem_rl
+  modules/ppo
   modules/sac
   modules/td3
@@ -47,12 +49,12 @@ To cite this project in publications:

 .. code-block:: bibtex

   @misc{torchy-baselines,
-    author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
+    author = {Raffin, Antonin and Dormann, Noah and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi},
     title = {Torchy Baselines},
     year = {2019},
     publisher = {GitHub},
     journal = {GitHub repository},
-    howpublished = {\url{https://github.com/hill-a/stable-baselines}},
+    howpublished = {\url{https://github.com/araffin/torchy-baselines}},
   }

 Indices and tables
@@ -3,17 +3,14 @@
 Changelog
 ==========

-Pre-Release 0.0.3a0 (WIP)
--------------------------
-**Initial Release**
+Pre-Release 0.2.0a0 (WIP)
+------------------------------

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^

 New Features:
 ^^^^^^^^^^^^^
-- Initial release of CEM-RL, PPO, SAC and TD3

 Bug Fixes:
 ^^^^^^^^^^
@@ -22,6 +19,32 @@ Deprecations:
 ^^^^^^^^^^^^^

 Others:
 ^^^^^^^

 Documentation:
 ^^^^^^^^^^^^^^
+- fix documentation build
+
+
+Pre-Release 0.1.0 (2020-01-20)
+------------------------------
+**First Release: base algorithms and state-dependent exploration**
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+
+New Features:
+^^^^^^^^^^^^^
+- Initial release of A2C, CEM-RL, PPO, SAC and TD3, working only with `Box` input space
+- State-Dependent Exploration (SDE) for A2C, PPO, SAC and TD3
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
docs/modules/a2c.rst (new file, 73 lines)
@@ -0,0 +1,73 @@
.. _a2c:

.. automodule:: torchy_baselines.a2c


A2C
====

A synchronous, deterministic variant of `Asynchronous Advantage Actor Critic (A3C) <https://arxiv.org/abs/1602.01783>`_.
It uses multiple workers to avoid the use of a replay buffer.


Notes
-----

- Original paper: https://arxiv.org/abs/1602.01783
- OpenAI blog post: https://openai.com/blog/baselines-acktr-a2c/


Can I use?
----------

- Recurrent policies: ✔️
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
============= ====== ===========


Example
-------

Train an A2C agent on `CartPole-v1` using 4 processes.

.. code-block:: python

  import gym

  from torchy_baselines.common.policies import MlpPolicy
  from torchy_baselines.common import make_vec_env
  from torchy_baselines import A2C

  # Parallel environments
  env = make_vec_env('CartPole-v1', n_envs=4)

  model = A2C(MlpPolicy, env, verbose=1)
  model.learn(total_timesteps=25000)
  model.save("a2c_cartpole")

  del model  # remove to demonstrate saving and loading

  model = A2C.load("a2c_cartpole")

  obs = env.reset()
  while True:
      action, _states = model.predict(obs)
      obs, rewards, dones, info = env.step(action)
      env.render()

Parameters
----------

.. autoclass:: A2C
  :members:
  :inherited-members:
docs/modules/cem_rl.rst (new file, 96 lines)
@@ -0,0 +1,96 @@
.. _cem_rl:

.. automodule:: torchy_baselines.cem_rl


CEM RL
======

Combining cross-entropy method (CEM) and Twin Delayed Deep Deterministic policy gradient (TD3).


.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1810.01222 and https://openreview.net/forum?id=BkeU5j0ctQ
- Original Implementation: https://github.com/apourchot/CEM-RL


.. note::

  CEM RL is currently implemented for TD3


.. note::

  The default policies for CEM RL differ a bit from the other MlpPolicy classes: they use ReLU instead of tanh activation,
  to match the original paper


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ❌
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ❌
Box           ✔️      ✔️
MultiDiscrete ❌      ❌
MultiBinary   ❌      ❌
============= ====== ===========


Example
-------

.. code-block:: python

  import numpy as np

  from torchy_baselines import CEMRL
  from torchy_baselines.td3.policies import MlpPolicy

  # n_grad = 0 corresponds to CEM (in fact CMA-ES without history)
  model = CEMRL(MlpPolicy, 'Pendulum-v0', pop_size=10, n_grad=5, verbose=1)
  model.learn(total_timesteps=50000, log_interval=10)
  model.save("td3_pendulum")
  env = model.get_env()

  del model  # remove to demonstrate saving and loading

  model = CEMRL.load("td3_pendulum")

  obs = env.reset()
  while True:
      action, _states = model.predict(obs)
      obs, rewards, dones, info = env.step(action)
      env.render()

Parameters
----------

.. autoclass:: CEMRL
  :members:
  :inherited-members:

.. _cemrl_policies:

CEM RL Policies
---------------

.. autoclass:: MlpPolicy
  :members:
  :inherited-members:
@@ -66,7 +66,7 @@ Example

     from torchy_baselines import TD3
     from torchy_baselines.td3.policies import MlpPolicy
-    from torchy_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
+    from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

     # The noise objects for TD3
     n_actions = env.action_space.shape[-1]
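The hunk cuts off right after ``n_actions``. For readability, a plausible continuation of this example, modelled on the analogous Stable Baselines TD3 snippet and not part of the diff (it assumes ``import numpy as np`` and an ``env`` created earlier in the doc page; the values are hypothetical):

```python
# Hypothetical continuation: zero-mean Gaussian noise on every action dimension
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000)
```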
docs/spelling_wordlist.txt (new file, 115 lines)
@@ -0,0 +1,115 @@
py
env
atari
argparse
Argparse
TensorFlow
feedforward
envs
VecEnv
pretrain
petrained
tf
np
mujoco
cpu
ndarray
ndarrays
timestep
timesteps
stepsize
dataset
adam
fn
normalisation
Kullback
Leibler
boolean
deserialized
pretrained
minibatch
subprocesses
ArgumentParser
Tensorflow
Gaussian
approximator
minibatches
hyperparameters
hyperparameter
vectorized
rl
colab
dataloader
npz
datasets
vf
logits
num
Utils
backpropagate
prepend
NaN
preprocessing
Cloudpickle
async
multiprocess
tensorflow
mlp
cnn
neglogp
tanh
coef
repo
Huber
params
ppo
arxiv
Arxiv
func
DQN
Uhlenbeck
Ornstein
multithread
cancelled
Tensorboard
parallelize
customising
serializable
Multiprocessed
cartpole
toolset
lstm
rescale
ffmpeg
avconv
unnormalized
Github
pre
preprocess
backend
attr
preprocess
Antonin
Raffin
araffin
Homebrew
Numpy
Theano
rollout
kfac
Piecewise
csv
nvidia
visdom
tensorboard
preprocessed
namespace
sklearn
GoalEnv
Torchy
pytorch
dicts
optimizers
Deprecations
forkserver
cuda
setup.py (6 changes)
@@ -22,7 +22,9 @@ setup(name='torchy_baselines',
          'docs': [
              'sphinx',
              'sphinx-autobuild',
-             'sphinx-rtd-theme'
+             'sphinx-rtd-theme',
+             # For spelling
+             'sphinxcontrib.spelling'
          ],
          'extra': [
              # For render
@@ -40,7 +42,7 @@ setup(name='torchy_baselines',
       license="MIT",
       long_description="",
       long_description_content_type='text/markdown',
-      version="0.1.0",
+      version="0.2.0a0",
       )

 # python setup.py sdist
@@ -4,4 +4,4 @@ from torchy_baselines.ppo import PPO
 from torchy_baselines.sac import SAC
 from torchy_baselines.td3 import TD3

-__version__ = "0.1.0"
+__version__ = "0.2.0a0"
@@ -35,9 +35,17 @@ class CEMRL(TD3):
     :param batch_size: (int) Minibatch size for each gradient update
     :param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1)
     :param action_noise: (ActionNoise) the action noise type. Cf common.noise for the different action noise types.
-    :param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
+    :param target_policy_noise: (float) Standard deviation of Gaussian noise added to target policy
         (smoothing noise)
     :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
+    :param n_episodes_rollout: (int) Update the model every `n_episodes_rollout` episodes.
+        Note that this cannot be used at the same time as `train_freq`
+    :param update_style: (str) Update style for the individual that will use the gradient:
+        - original: original implementation (actor_steps // n_grad steps for the critic
+          and actor_steps gradient steps per individual)
+        - original_td3: same as before but the target networks are only updated afterward
+        - td3_like: use policy delay and `actor_steps` steps for both the critic and the individual
+        - other: `2 * (actor_steps // self.n_grad)` for the critic and the individual
     :param create_eval_env: (bool) Whether to create a second environment that will be
         used for evaluating the agent periodically. (Only available when passing string for the environment)
     :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
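Since ``update_style`` and ``n_episodes_rollout`` appear only in this docstring, here is a hypothetical invocation tying them together; parameter names come from the docstring above and from the CEM RL example earlier on this page, while the values are purely illustrative:

```python
from torchy_baselines import CEMRL
from torchy_baselines.td3.policies import MlpPolicy

# Illustrative values; `update_style` choices per the docstring:
# 'original', 'original_td3', 'td3_like' or 'other'
model = CEMRL(MlpPolicy, 'Pendulum-v0',
              pop_size=10, n_grad=5,
              n_episodes_rollout=1,        # cannot be combined with train_freq
              update_style='original_td3',
              target_policy_noise=0.2, target_noise_clip=0.5,
              verbose=1)
model.learn(total_timesteps=50000)
```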
@@ -183,7 +183,7 @@ class BaseRLModel(object):
 def safe_mean(arr):
     """
     Compute the mean of an array if there is at least one element.
-    For empty array, return nan. It is used for logging only.
+    For empty array, return NaN. It is used for logging only.

     :param arr: (np.ndarray)
     :return: (float)
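The function body lies outside the hunk; a minimal sketch consistent with this docstring (assuming NumPy is available as ``np``):

```python
import numpy as np

def safe_mean(arr):
    # Mean of a possibly empty array; NaN signals "no data yet" to the logger
    return np.nan if len(arr) == 0 else np.mean(arr)
```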
@@ -192,7 +192,7 @@ class BaseRLModel(object):

 def get_env(self):
     """
-    returns the current environment (can be None if not defined)
+    Returns the current environment (can be None if not defined).

     :return: (gym.Env) The current environment
     """
@@ -201,10 +201,10 @@ class BaseRLModel(object):
 @staticmethod
 def check_env(env, observation_space, action_space):
     """
-    Checks the validity of the environment and returns if it is coherent
+    Checks the validity of the environment and returns if it is consistent.
     Checked parameters:
-    - observation_space
-    - action_space
+     - observation_space
+     - action_space

     :return: (bool) True if environment seems to be coherent
     """
     if observation_space != env.observation_space:
@@ -219,8 +219,8 @@ class BaseRLModel(object):
     Checks the validity of the environment, and if it is coherent, set it as the current environment.
     Furthermore wrap any non vectorized env into a vectorized one.
     checked parameters:
-    - observation_space
-    - action_space
+     - observation_space
+     - action_space

     :param env: (gym.Env) The environment for learning a policy
     """
@@ -312,7 +312,7 @@ class BaseRLModel(object):
     Load the model from a zip-file

     :param load_path: (str) the location of the saved data
-    :param env: (Gym Envrionment) the new environment to run the loaded model on
+    :param env: (Gym Environment) the new environment to run the loaded model on
         (can be None if you only need prediction from a trained model) has priority over any saved environment
     :param kwargs: extra arguments to change the model when loading
     """
@@ -162,7 +162,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
     # It will be clipped to avoid NaN when inversing tanh
     gaussian_action = TanhBijector.inverse(action)

-    # Log likelihood for a gaussian distribution
+    # Log likelihood for a Gaussian distribution
     log_prob = super(SquashedDiagGaussianDistribution, self).log_prob(gaussian_action)
     # Squash correction (from original SAC implementation)
     # this comes from the fact that tanh is bijective and differentiable
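For reference, the squash correction the comments allude to is the change-of-variables term for ``a = tanh(u)``; a sketch of the standard computation, where the epsilon constant and the reduction dimension are assumptions rather than values taken from the diff:

```python
import torch as th

# log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2), with a = tanh(u)
eps = 1e-6  # assumed small constant to avoid log(0) at tanh saturation
log_prob -= th.sum(th.log(1.0 - th.tanh(gaussian_action) ** 2 + eps), dim=1)
```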
@@ -289,7 +289,7 @@ class StateDependentNoiseDistribution(Distribution):
 def sample_weights(self, log_std, batch_size=1):
     """
     Sample weights for the noise exploration matrix,
-    using a centered gaussian distribution.
+    using a centered Gaussian distribution.

     :param log_std: (th.Tensor)
     :param batch_size: (int)
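A minimal sketch of what this docstring describes, not the verbatim body; the ``exploration_mat`` attribute name is hypothetical:

```python
import torch as th

def sample_weights(self, log_std, batch_size=1):
    # Draw the exploration matrix from a zero-mean ("centered") Gaussian
    # whose per-entry standard deviation is exp(log_std);
    # batch_size is kept for parity with the documented signature
    std = th.exp(log_std)
    weights_dist = th.distributions.Normal(th.zeros_like(std), std)
    self.exploration_mat = weights_dist.rsample()  # reparameterized sample
```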
@@ -17,7 +17,7 @@ class ActionNoise(object):

 class NormalActionNoise(ActionNoise):
     """
-    A gaussian action noise
+    A Gaussian action noise

     :param mean: (float) the mean value of the noise
     :param sigma: (float) the scale of the noise (std here)
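A minimal sketch of such a class under the docstring's contract; the attribute names are illustrative and the real body is outside the hunk:

```python
import numpy as np

class NormalActionNoise(ActionNoise):
    def __init__(self, mean, sigma):
        self._mu = mean      # mean of the noise
        self._sigma = sigma  # standard deviation ("scale") of the noise

    def __call__(self):
        # A fresh Gaussian sample per call; the algorithm adds it to the action
        return np.random.normal(self._mu, self._sigma)
```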
@@ -74,7 +74,7 @@ class Actor(BaseNetwork):
     self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=net_arch[-1],
                                                                     latent_sde_dim=latent_sde_dim,
                                                                     log_std_init=log_std_init)
-    # Avoid saturation by limiting the mean of the gaussian to be in [-1, 1]
+    # Avoid saturation by limiting the mean of the Gaussian to be in [-1, 1]
     # self.mu = nn.Sequential(self.mu, nn.Tanh())
     self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-2.0, max_val=2.0))
     # Small positive slope to have non-zero gradient
@@ -37,6 +37,8 @@ class SAC(BaseRLModel):
     :param target_update_interval: (int) update the target network every `target_network_update_freq` steps.
     :param train_freq: (int) Update the model every `train_freq` steps.
     :param gradient_steps: (int) How many gradient updates after each step
+    :param n_episodes_rollout: (int) Update the model every `n_episodes_rollout` episodes.
+        Note that this cannot be used at the same time as `train_freq`
     :param target_entropy: (str or float) target entropy when learning ent_coef (ent_coef = 'auto')
     :param action_noise: (ActionNoise) the action noise type (None by default), this can help
         for hard exploration problems. Cf common.noise for the different action noise types.
@@ -29,9 +29,11 @@ class TD3(BaseRLModel):
     :param batch_size: (int) Minibatch size for each gradient update
     :param train_freq: (int) Update the model every `train_freq` steps.
     :param gradient_steps: (int) How many gradient updates after each step
+    :param n_episodes_rollout: (int) Update the model every `n_episodes_rollout` episodes.
+        Note that this cannot be used at the same time as `train_freq`
     :param tau: (float) the soft update coefficient ("polyak update" of the target networks, between 0 and 1)
     :param action_noise: (ActionNoise) the action noise type. Cf common.noise for the different action noise types.
-    :param target_policy_noise: (float) Standard deviation of gaussian noise added to target policy
+    :param target_policy_noise: (float) Standard deviation of Gaussian noise added to target policy
         (smoothing noise)
     :param target_noise_clip: (float) Limit for absolute value of target policy smoothing noise.
     :param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
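The two noise parameters implement TD3's standard target-policy smoothing; a sketch of that computation, with hypothetical tensor names and actions assumed scaled to [-1, 1]:

```python
import torch as th

# Perturb the target action with clipped Gaussian noise before computing the
# target Q-value, so the critic cannot exploit sharp peaks in the Q-function
noise = th.randn_like(target_action) * target_policy_noise
noise = noise.clamp(-target_noise_clip, target_noise_clip)
smoothed_target_action = (target_action + noise).clamp(-1.0, 1.0)
```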