mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Update README + style fixes
This commit is contained in:
parent
d8c54313e5
commit
f068ada442
2 changed files with 39 additions and 34 deletions
13
README.md
13
README.md
|
|
@ -1,13 +1,18 @@
|
|||
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
|
||||
|
||||
[](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
|
||||
[](https://gitlab.com/araffin/stable-baselines3/-/commits/sde) [](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [](https://gitlab.com/araffin/stable-baselines3/-/commits/sde)
|
||||
|
||||
|
||||
**WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released**
|
||||
# Generalized State-Dependent Exploration (gSDE) for Deep Reinforcement Learning in Robotics
|
||||
|
||||
Note: most of the documentation of [Stable Baselines](https://github.com/hill-a/stable-baselines) should be still valid though.
|
||||
This branch contains the code for reproducing the results in the paper "Generalized State-Dependent Exploration for Deep Reinforcement Learning in Robotics" by Antonin Raffin and Freek Stulp.
|
||||
|
||||
# Stable Baselines3
|
||||
Arxiv: https://arxiv.org/abs/2005.05719
|
||||
|
||||
The main difference with the master branch is that TD3 has support in that branch for gSDE.
|
||||
|
||||
|
||||
## Stable Baselines3
|
||||
|
||||
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
|||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import StateDependentNoiseDistribution
|
||||
|
||||
|
||||
|
|
@ -109,19 +109,19 @@ class Actor(BasePolicy):
|
|||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
clip_noise=self.clip_noise,
|
||||
lr_sde=self.lr_sde,
|
||||
full_std=self.full_std,
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=self.use_expln,
|
||||
features_extractor=self.features_extractor
|
||||
))
|
||||
data.update(dict(net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
clip_noise=self.clip_noise,
|
||||
lr_sde=self.lr_sde,
|
||||
full_std=self.full_std,
|
||||
sde_net_arch=self.sde_net_arch,
|
||||
use_expln=self.use_expln,
|
||||
features_extractor=self.features_extractor
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
def get_std(self) -> th.Tensor:
|
||||
|
|
@ -389,21 +389,21 @@ class TD3Policy(BasePolicy):
|
|||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_args['net_arch'],
|
||||
activation_fn=self.net_args['activation_fn'],
|
||||
use_sde=self.actor_kwargs['use_sde'],
|
||||
log_std_init=self.actor_kwargs['log_std_init'],
|
||||
clip_noise=self.actor_kwargs['clip_noise'],
|
||||
lr_sde=self.actor_kwargs['lr_sde'],
|
||||
sde_net_arch=self.actor_kwargs['sde_net_arch'],
|
||||
use_expln=self.actor_kwargs['use_expln'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
))
|
||||
data.update(dict(net_arch=self.net_args['net_arch'],
|
||||
activation_fn=self.net_args['activation_fn'],
|
||||
use_sde=self.actor_kwargs['use_sde'],
|
||||
log_std_init=self.actor_kwargs['log_std_init'],
|
||||
clip_noise=self.actor_kwargs['clip_noise'],
|
||||
lr_sde=self.actor_kwargs['lr_sde'],
|
||||
sde_net_arch=self.actor_kwargs['sde_net_arch'],
|
||||
use_expln=self.actor_kwargs['use_expln'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
)
|
||||
)
|
||||
return data
|
||||
|
||||
def reset_noise(self) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue