Cleanup + reformat code

This commit is contained in:
Antonin RAFFIN 2020-05-08 15:10:46 +02:00
parent c20af230f7
commit 413a2386d9
8 changed files with 52 additions and 44 deletions

View file

@ -240,8 +240,7 @@ class BaseRLModel(ABC):
if (observation_space != env.observation_space
# Special cases for images that need to be transposed
and not (is_image_space(env.observation_space)
and observation_space == VecTransposeImage.transpose_space(env.observation_space)
)):
and observation_space == VecTransposeImage.transpose_space(env.observation_space))):
raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}')
if action_space != env.action_space:
raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}')
@ -820,11 +819,11 @@ class OffPolicyRLModel(BaseRLModel):
if action_noise is not None:
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
# Update(October 2019): Not anymore
clipped_action = np.clip(scaled_action + action_noise(), -1, 1)
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
# We store the scaled action in the buffer
buffer_action = clipped_action
action = self.policy.unscale_action(clipped_action)
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
else:
# Discrete case, no need to normalize or clip
buffer_action = unscaled_action

View file

@ -20,6 +20,7 @@ class KVWriter(object):
"""
Key Value writer
"""
def writekvs(self, kvs: Dict) -> None:
"""
write a dictionary to file
@ -39,6 +40,7 @@ class SeqWriter(object):
"""
sequence writer
"""
def writeseq(self, seq: List):
"""
write an array to file
@ -49,7 +51,7 @@ class SeqWriter(object):
class HumanOutputFormat(KVWriter, SeqWriter):
def __init__(self, filename_or_file: Union [str, TextIO]):
def __init__(self, filename_or_file: Union[str, TextIO]):
"""
log to a file, in a human readable format

View file

@ -166,7 +166,7 @@ class BasePolicy(nn.Module):
module.bias.data.fill_(0.0)
@staticmethod
def _dummy_schedule(progress: float) -> float:
def _dummy_schedule(_progress: float) -> float:
""" (float) Useful for pickling policy."""
return 0.0

View file

@ -4,7 +4,7 @@ from typing import Optional, Union
from copy import deepcopy
from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
VecEnv, VecEnvWrapper, CloudpickleWrapper)
VecEnv, VecEnvWrapper, CloudpickleWrapper)
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack

View file

@ -11,8 +11,9 @@ from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, ob
class DummyVecEnv(VecEnv):
"""
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
Python process. This is useful for computationally simple environment such as ````cartpole-v1````, as the overhead of
multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that
Python process. This is useful for computationally simple environments such as ``cartpole-v1``,
as the overhead of multiprocessing or multithreading outweighs the environment computation time.
This can also be used for RL methods that
require a vectorized environment, but where you want a single environment to train with.
:param env_fns: ([Gym Environment]) the list of environments to vectorize

View file

@ -7,11 +7,11 @@ import torch.nn as nn
import numpy as np
from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
DiagGaussianDistribution, CategoricalDistribution,
StateDependentNoiseDistribution)
DiagGaussianDistribution, CategoricalDistribution,
StateDependentNoiseDistribution)
class PPOPolicy(BasePolicy):
@ -47,6 +47,7 @@ class PPOPolicy(BasePolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@ -122,20 +123,20 @@ class PPOPolicy(BasePolicy):
data = super()._get_data()
data.update(dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data
@ -145,7 +146,8 @@ class PPOPolicy(BasePolicy):
:param n_envs: (int)
"""
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
assert isinstance(self.action_dist,
StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, lr_schedule: Callable) -> None:
@ -319,6 +321,7 @@ class CnnPolicy(PPOPolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,

View file

@ -6,8 +6,8 @@ import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
create_sde_features_extractor, NatureCNN,
BaseFeaturesExtractor, FlattenExtractor)
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
# CAP the standard deviation of the actor

View file

@ -6,7 +6,7 @@ import torch.nn as nn
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)
class Actor(BasePolicy):
@ -24,6 +24,7 @@ class Actor(BasePolicy):
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
@ -39,7 +40,6 @@ class Actor(BasePolicy):
device=device,
squash_output=True)
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self.net_arch = net_arch
@ -55,10 +55,10 @@ class Actor(BasePolicy):
data = super()._get_data()
data.update(dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor
))
return data
@ -87,6 +87,7 @@ class Critic(BasePolicy):
dividing by 255.0 (True by default)
:param device: (Union[th.device, str]) Device on which the code should run.
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
net_arch: List[int],
@ -141,6 +142,7 @@ class TD3Policy(BasePolicy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
@ -204,13 +206,13 @@ class TD3Policy(BasePolicy):
data = super()._get_data()
data.update(dict(
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
net_arch=self.net_args['net_arch'],
activation_fn=self.net_args['activation_fn'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs
))
return data
@ -250,6 +252,7 @@ class CnnPolicy(TD3Policy):
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,