diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 6ab7477..7b90c72 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -240,8 +240,7 @@ class BaseRLModel(ABC): if (observation_space != env.observation_space # Special cases for images that need to be transposed and not (is_image_space(env.observation_space) - and observation_space == VecTransposeImage.transpose_space(env.observation_space) - )): + and observation_space == VecTransposeImage.transpose_space(env.observation_space))): raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}') if action_space != env.action_space: raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}') @@ -820,11 +819,11 @@ class OffPolicyRLModel(BaseRLModel): if action_noise is not None: # NOTE: in the original implementation of TD3, the noise was applied to the unscaled action # Update(October 2019): Not anymore - clipped_action = np.clip(scaled_action + action_noise(), -1, 1) + scaled_action = np.clip(scaled_action + action_noise(), -1, 1) # We store the scaled action in the buffer - buffer_action = clipped_action - action = self.policy.unscale_action(clipped_action) + buffer_action = scaled_action + action = self.policy.unscale_action(scaled_action) else: # Discrete case, no need to normalize or clip buffer_action = unscaled_action diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index 15074bf..9269eb5 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -20,6 +20,7 @@ class KVWriter(object): """ Key Value writer """ + def writekvs(self, kvs: Dict) -> None: """ write a dictionary to file @@ -39,6 +40,7 @@ class SeqWriter(object): """ sequence writer """ + def writeseq(self, seq: List): """ write an array to file @@ -49,7 +51,7 @@ class SeqWriter(object): class HumanOutputFormat(KVWriter, SeqWriter): - def __init__(self, filename_or_file: Union [str, TextIO]): + def __init__(self, filename_or_file: Union[str, TextIO]): """ log to a file, in a human readable format diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 22b51df..aa375aa 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -166,7 +166,7 @@ class BasePolicy(nn.Module): module.bias.data.fill_(0.0) @staticmethod - def _dummy_schedule(progress: float) -> float: + def _dummy_schedule(_progress: float) -> float: """ (float) Useful for pickling policy.""" return 0.0 diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 8f7719a..535ed03 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -4,7 +4,7 @@ from typing import Optional, Union from copy import deepcopy from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError, - VecEnv, VecEnvWrapper, CloudpickleWrapper) + VecEnv, VecEnvWrapper, CloudpickleWrapper) from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index c5451f0..962f655 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -11,8 +11,9 @@ from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, ob class DummyVecEnv(VecEnv): """ Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current - Python process. This is useful for computationally simple environment such as ````cartpole-v1````, as the overhead of - multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that + Python process. This is useful for computationally simple environment such as ``cartpole-v1``, + as the overhead of multiprocess or multithread outweighs the environment computation time. + This can also be used for RL methods that require a vectorized environment, but that you want a single environments to train with. :param env_fns: ([Gym Environment]) the list of environments to vectorize diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index e997769..8742220 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -7,11 +7,11 @@ import torch.nn as nn import numpy as np from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor, - create_sde_features_extractor, NatureCNN, - BaseFeaturesExtractor, FlattenExtractor) + create_sde_features_extractor, NatureCNN, + BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, - DiagGaussianDistribution, CategoricalDistribution, - StateDependentNoiseDistribution) + DiagGaussianDistribution, CategoricalDistribution, + StateDependentNoiseDistribution) class PPOPolicy(BasePolicy): @@ -47,6 +47,7 @@ class PPOPolicy(BasePolicy): :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, @@ -122,20 +123,20 @@ class PPOPolicy(BasePolicy): data = super()._get_data() data.update(dict( - net_arch=self.net_arch, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None, - full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None, - sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, - use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - ortho_init=self.ortho_init, - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None, + full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None, + sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, + use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs )) return data @@ -145,7 +146,8 @@ class PPOPolicy(BasePolicy): :param n_envs: (int) """ - assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE' + assert isinstance(self.action_dist, + StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE' self.action_dist.sample_weights(self.log_std, batch_size=n_envs) def _build(self, lr_schedule: Callable) -> None: @@ -319,6 +321,7 @@ class CnnPolicy(PPOPolicy): :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 868fde6..d6cb470 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -6,8 +6,8 @@ import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp, - create_sde_features_extractor, NatureCNN, - BaseFeaturesExtractor, FlattenExtractor) + create_sde_features_extractor, NatureCNN, + BaseFeaturesExtractor, FlattenExtractor) from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution # CAP the standard deviation of the actor diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index a90324a..9a67911 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -6,7 +6,7 @@ import torch.nn as nn from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp, - NatureCNN, BaseFeaturesExtractor, FlattenExtractor) + NatureCNN, BaseFeaturesExtractor, FlattenExtractor) class Actor(BasePolicy): @@ -24,6 +24,7 @@ class Actor(BasePolicy): dividing by 255.0 (True by default) :param device: (Union[th.device, str]) Device on which the code should run. """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, @@ -39,7 +40,6 @@ class Actor(BasePolicy): device=device, squash_output=True) - self.features_extractor = features_extractor self.normalize_images = normalize_images self.net_arch = net_arch @@ -55,10 +55,10 @@ class Actor(BasePolicy): data = super()._get_data() data.update(dict( - net_arch=self.net_arch, - features_dim=self.features_dim, - activation_fn=self.activation_fn, - features_extractor=self.features_extractor + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor )) return data @@ -87,6 +87,7 @@ class Critic(BasePolicy): dividing by 255.0 (True by default) :param device: (Union[th.device, str]) Device on which the code should run. """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, net_arch: List[int], @@ -141,6 +142,7 @@ class TD3Policy(BasePolicy): :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, lr_schedule: Callable, @@ -204,13 +206,13 @@ class TD3Policy(BasePolicy): data = super()._get_data() data.update(dict( - net_arch=self.net_args['net_arch'], - activation_fn=self.net_args['activation_fn'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs + net_arch=self.net_args['net_arch'], + activation_fn=self.net_args['activation_fn'], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs )) return data @@ -250,6 +252,7 @@ class CnnPolicy(TD3Policy): :param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments, excluding the learning rate, to pass to the optimizer """ + def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, lr_schedule: Callable,