From f068ada44230670cae9f75c993f01704fcf96d68 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 15 May 2020 15:07:22 +0200 Subject: [PATCH] Update README + style fixes --- README.md | 13 ++++--- stable_baselines3/td3/policies.py | 60 +++++++++++++++---------------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 309d524..8f4edcc 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,18 @@ -[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) +[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/sde/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/sde/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/sde) -**WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released** +# Generalized State-Dependent Exploration (gSDE) for Deep Reinforcement Learning in Robotics -Note: most of the documentation of [Stable Baselines](https://github.com/hill-a/stable-baselines) should be still valid though. +This branch contains the code for reproducing the results in the paper "Generalized State-Dependent Exploration for Deep Reinforcement Learning in Robotics" by Antonin Raffin and Freek Stulp. -# Stable Baselines3 +Arxiv: https://arxiv.org/abs/2005.05719 + +The main difference with the master branch is that TD3 has support in that branch for gSDE. + + +## Stable Baselines3 Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index edbb886..e0ea24c 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -6,8 +6,8 @@ import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp, - create_sde_features_extractor, NatureCNN, - BaseFeaturesExtractor, FlattenExtractor) + create_sde_features_extractor, NatureCNN, + BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import StateDependentNoiseDistribution @@ -109,19 +109,19 @@ class Actor(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_arch, - features_dim=self.features_dim, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - clip_noise=self.clip_noise, - lr_sde=self.lr_sde, - full_std=self.full_std, - sde_net_arch=self.sde_net_arch, - use_expln=self.use_expln, - features_extractor=self.features_extractor - )) + data.update(dict(net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + clip_noise=self.clip_noise, + lr_sde=self.lr_sde, + full_std=self.full_std, + sde_net_arch=self.sde_net_arch, + use_expln=self.use_expln, + features_extractor=self.features_extractor + ) + ) return data def get_std(self) -> th.Tensor: @@ -389,21 +389,21 @@ class TD3Policy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_args['net_arch'], - activation_fn=self.net_args['activation_fn'], - use_sde=self.actor_kwargs['use_sde'], - log_std_init=self.actor_kwargs['log_std_init'], - clip_noise=self.actor_kwargs['clip_noise'], - lr_sde=self.actor_kwargs['lr_sde'], - sde_net_arch=self.actor_kwargs['sde_net_arch'], - use_expln=self.actor_kwargs['use_expln'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) + data.update(dict(net_arch=self.net_args['net_arch'], + activation_fn=self.net_args['activation_fn'], + use_sde=self.actor_kwargs['use_sde'], + log_std_init=self.actor_kwargs['log_std_init'], + clip_noise=self.actor_kwargs['clip_noise'], + lr_sde=self.actor_kwargs['lr_sde'], + sde_net_arch=self.actor_kwargs['sde_net_arch'], + use_expln=self.actor_kwargs['use_expln'], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs + ) + ) return data def reset_noise(self) -> None: