Update README + style fixes

This commit is contained in:
Antonin RAFFIN 2020-05-15 15:07:22 +02:00
parent d8c54313e5
commit f068ada442
2 changed files with 39 additions and 34 deletions

View file

@ -1,13 +1,18 @@
<img src="docs/\_static/img/logo.png" align="right" width="40%"/>
[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/sde/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/sde/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde)
**WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released**
# Generalized State-Dependent Exploration (gSDE) for Deep Reinforcement Learning in Robotics
Note: most of the documentation of [Stable Baselines](https://github.com/hill-a/stable-baselines) should be still valid though.
This branch contains the code for reproducing the results in the paper "Generalized State-Dependent Exploration for Deep Reinforcement Learning in Robotics" by Antonin Raffin and Freek Stulp.
# Stable Baselines3
Arxiv: https://arxiv.org/abs/2005.05719
The main difference with the master branch is that TD3 has support in that branch for gSDE.
## Stable Baselines3
Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

View file

@ -6,8 +6,8 @@ import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import StateDependentNoiseDistribution
@ -109,19 +109,19 @@ class Actor(BasePolicy):
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
clip_noise=self.clip_noise,
lr_sde=self.lr_sde,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor
))
data.update(dict(net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
clip_noise=self.clip_noise,
lr_sde=self.lr_sde,
full_std=self.full_std,
sde_net_arch=self.sde_net_arch,
use_expln=self.use_expln,
features_extractor=self.features_extractor
)
)
return data
def get_std(self) -> th.Tensor:
@ -389,21 +389,21 @@ class TD3Policy(BasePolicy):
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
use_sde=self.actor_kwargs['use_sde'],
log_std_init=self.actor_kwargs['log_std_init'],
clip_noise=self.actor_kwargs['clip_noise'],
lr_sde=self.actor_kwargs['lr_sde'],
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
data.update(dict(net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
use_sde=self.actor_kwargs['use_sde'],
log_std_init=self.actor_kwargs['log_std_init'],
clip_noise=self.actor_kwargs['clip_noise'],
lr_sde=self.actor_kwargs['lr_sde'],
sde_net_arch=self.actor_kwargs['sde_net_arch'],
use_expln=self.actor_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
)
)
return data
def reset_noise(self) -> None: