mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Cleanup + reformat code
This commit is contained in:
parent
c20af230f7
commit
413a2386d9
8 changed files with 52 additions and 44 deletions
|
|
@ -240,8 +240,7 @@ class BaseRLModel(ABC):
|
|||
if (observation_space != env.observation_space
|
||||
# Special cases for images that need to be transposed
|
||||
and not (is_image_space(env.observation_space)
|
||||
and observation_space == VecTransposeImage.transpose_space(env.observation_space)
|
||||
)):
|
||||
and observation_space == VecTransposeImage.transpose_space(env.observation_space))):
|
||||
raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}')
|
||||
if action_space != env.action_space:
|
||||
raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}')
|
||||
|
|
@ -820,11 +819,11 @@ class OffPolicyRLModel(BaseRLModel):
|
|||
if action_noise is not None:
|
||||
# NOTE: in the original implementation of TD3, the noise was applied to the unscaled action
|
||||
# Update(October 2019): Not anymore
|
||||
clipped_action = np.clip(scaled_action + action_noise(), -1, 1)
|
||||
scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
|
||||
|
||||
# We store the scaled action in the buffer
|
||||
buffer_action = clipped_action
|
||||
action = self.policy.unscale_action(clipped_action)
|
||||
buffer_action = scaled_action
|
||||
action = self.policy.unscale_action(scaled_action)
|
||||
else:
|
||||
# Discrete case, no need to normalize or clip
|
||||
buffer_action = unscaled_action
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class KVWriter(object):
|
|||
"""
|
||||
Key Value writer
|
||||
"""
|
||||
|
||||
def writekvs(self, kvs: Dict) -> None:
|
||||
"""
|
||||
write a dictionary to file
|
||||
|
|
@ -39,6 +40,7 @@ class SeqWriter(object):
|
|||
"""
|
||||
sequence writer
|
||||
"""
|
||||
|
||||
def writeseq(self, seq: List):
|
||||
"""
|
||||
write an array to file
|
||||
|
|
@ -49,7 +51,7 @@ class SeqWriter(object):
|
|||
|
||||
|
||||
class HumanOutputFormat(KVWriter, SeqWriter):
|
||||
def __init__(self, filename_or_file: Union [str, TextIO]):
|
||||
def __init__(self, filename_or_file: Union[str, TextIO]):
|
||||
"""
|
||||
log to a file, in a human readable format
|
||||
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ class BasePolicy(nn.Module):
|
|||
module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(progress: float) -> float:
|
||||
def _dummy_schedule(_progress: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
return 0.0
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Optional, Union
|
|||
from copy import deepcopy
|
||||
|
||||
from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
|
||||
VecEnv, VecEnvWrapper, CloudpickleWrapper)
|
||||
VecEnv, VecEnvWrapper, CloudpickleWrapper)
|
||||
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, ob
|
|||
class DummyVecEnv(VecEnv):
|
||||
"""
|
||||
Creates a simple vectorized wrapper for multiple environments, calling each environment in sequence on the current
|
||||
Python process. This is useful for computationally simple environment such as ````cartpole-v1````, as the overhead of
|
||||
multiprocess or multithread outweighs the environment computation time. This can also be used for RL methods that
|
||||
Python process. This is useful for computationally simple environment such as ``cartpole-v1``,
|
||||
as the overhead of multiprocess or multithread outweighs the environment computation time.
|
||||
This can also be used for RL methods that
|
||||
require a vectorized environment, but that you want a single environments to train with.
|
||||
|
||||
:param env_fns: ([Gym Environment]) the list of environments to vectorize
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ import torch.nn as nn
|
|||
import numpy as np
|
||||
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
|
||||
DiagGaussianDistribution, CategoricalDistribution,
|
||||
StateDependentNoiseDistribution)
|
||||
DiagGaussianDistribution, CategoricalDistribution,
|
||||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class PPOPolicy(BasePolicy):
|
||||
|
|
@ -47,6 +47,7 @@ class PPOPolicy(BasePolicy):
|
|||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
|
|
@ -122,20 +123,20 @@ class PPOPolicy(BasePolicy):
|
|||
data = super()._get_data()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
|
||||
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
|
||||
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
|
||||
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
|
||||
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
|
||||
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
|
||||
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
))
|
||||
return data
|
||||
|
||||
|
|
@ -145,7 +146,8 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
:param n_envs: (int)
|
||||
"""
|
||||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
|
||||
assert isinstance(self.action_dist,
|
||||
StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
|
|
@ -319,6 +321,7 @@ class CnnPolicy(PPOPolicy):
|
|||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
|||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
create_sde_features_extractor, NatureCNN,
|
||||
BaseFeaturesExtractor, FlattenExtractor)
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
|||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp,
|
||||
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)
|
||||
NatureCNN, BaseFeaturesExtractor, FlattenExtractor)
|
||||
|
||||
|
||||
class Actor(BasePolicy):
|
||||
|
|
@ -24,6 +24,7 @@ class Actor(BasePolicy):
|
|||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
|
|
@ -39,7 +40,6 @@ class Actor(BasePolicy):
|
|||
device=device,
|
||||
squash_output=True)
|
||||
|
||||
|
||||
self.features_extractor = features_extractor
|
||||
self.normalize_images = normalize_images
|
||||
self.net_arch = net_arch
|
||||
|
|
@ -55,10 +55,10 @@ class Actor(BasePolicy):
|
|||
data = super()._get_data()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
features_extractor=self.features_extractor
|
||||
net_arch=self.net_arch,
|
||||
features_dim=self.features_dim,
|
||||
activation_fn=self.activation_fn,
|
||||
features_extractor=self.features_extractor
|
||||
))
|
||||
return data
|
||||
|
||||
|
|
@ -87,6 +87,7 @@ class Critic(BasePolicy):
|
|||
dividing by 255.0 (True by default)
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
net_arch: List[int],
|
||||
|
|
@ -141,6 +142,7 @@ class TD3Policy(BasePolicy):
|
|||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
|
|
@ -204,13 +206,13 @@ class TD3Policy(BasePolicy):
|
|||
data = super()._get_data()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_args['net_arch'],
|
||||
activation_fn=self.net_args['activation_fn'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
net_arch=self.net_args['net_arch'],
|
||||
activation_fn=self.net_args['activation_fn'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
optimizer_class=self.optimizer_class,
|
||||
optimizer_kwargs=self.optimizer_kwargs,
|
||||
features_extractor_class=self.features_extractor_class,
|
||||
features_extractor_kwargs=self.features_extractor_kwargs
|
||||
))
|
||||
return data
|
||||
|
||||
|
|
@ -250,6 +252,7 @@ class CnnPolicy(TD3Policy):
|
|||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
|
|
|
|||
Loading…
Reference in a new issue