mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
Add squash_output attribute to policy
This commit is contained in:
parent
aa8b4eb22a
commit
a2b1bf06d3
6 changed files with 53 additions and 37 deletions
|
|
@ -86,7 +86,7 @@ class BaseRLModel(ABC):
|
|||
self.num_timesteps = 0
|
||||
self.eval_env = None
|
||||
self.seed = seed
|
||||
self.action_noise = None # type: ActionNoise
|
||||
self.action_noise = None # type: Optional[ActionNoise]
|
||||
self.start_time = None
|
||||
self.policy = None
|
||||
self.learning_rate = None
|
||||
|
|
@ -97,8 +97,8 @@ class BaseRLModel(ABC):
|
|||
# this is used to update the learning rate
|
||||
self._current_progress = 1
|
||||
# Buffers for logging
|
||||
self.ep_info_buffer = None # type: deque
|
||||
self.ep_success_buffer = None # type: deque
|
||||
self.ep_info_buffer = None # type: Optional[deque]
|
||||
self.ep_success_buffer = None # type: Optional[deque]
|
||||
|
||||
# Create and wrap the env if needed
|
||||
if env is not None:
|
||||
|
|
@ -387,13 +387,12 @@ class BaseRLModel(ABC):
|
|||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale to proper domain when using squashing
|
||||
# TODO: should not be used for a Gaussian distribution?
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output:
|
||||
actions = self.unscale_action(actions)
|
||||
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error when using gaussian distribution
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output:
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
if not vectorized_env:
|
||||
|
|
|
|||
|
|
@ -22,14 +22,14 @@ class BaseCallback(ABC):
|
|||
"""
|
||||
def __init__(self, verbose: int = 0):
|
||||
super(BaseCallback, self).__init__()
|
||||
self.model = None # type: BaseRLModel
|
||||
self.model = None # type: Optional[BaseRLModel]
|
||||
self.training_env = None # type: Union[gym.Env, VecEnv, None]
|
||||
self.n_calls = 0 # type: int
|
||||
self.num_timesteps = 0 # type: int
|
||||
self.verbose = verbose
|
||||
self.locals = None # type: Dict[str, Any]
|
||||
self.globals = None # type: Dict[str, Any]
|
||||
self.logger = None # type: Logger
|
||||
self.locals = None # type: Optional[Dict[str, Any]]
|
||||
self.globals = None # type: Optional[Dict[str, Any]]
|
||||
self.logger = None # type: Optional[Logger]
|
||||
# Sometimes, for event callback, it is useful
|
||||
# to have access to the parent object
|
||||
self.parent = None # type: Optional[BaseCallback]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Union
|
||||
from typing import Union, Type, Dict, List, Tuple
|
||||
|
||||
from itertools import zip_longest
|
||||
|
||||
|
|
@ -14,14 +14,24 @@ class BasePolicy(nn.Module):
|
|||
|
||||
:param observation_space: (gym.spaces.Space) The observation space of the environment
|
||||
:param action_space: (gym.spaces.Space) The action space of the environment
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
:param squash_output: (bool) For continuous actions, whether the output is squashed
|
||||
or not using a `tanh()` function.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'):
|
||||
action_space: gym.spaces.Space,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
squash_output: bool = False):
|
||||
super(BasePolicy, self).__init__()
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.device = device
|
||||
self._squash_output = squash_output
|
||||
|
||||
@property
|
||||
def squash_output(self) -> bool:
|
||||
""" (bool) Getter for squash_output."""
|
||||
return self._squash_output
|
||||
|
||||
@staticmethod
|
||||
def init_weights(module: nn.Module, gain: float = 1):
|
||||
|
|
@ -71,21 +81,25 @@ class BasePolicy(nn.Module):
|
|||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
|
||||
def create_mlp(input_dim, output_dim, net_arch,
|
||||
activation_fn=nn.ReLU, squash_out=False):
|
||||
def create_mlp(input_dim: int,
|
||||
output_dim: int,
|
||||
net_arch: List[int],
|
||||
activation_fn: nn.Module = nn.ReLU,
|
||||
squash_output: bool = False) -> List[nn.Module]:
|
||||
"""
|
||||
Create a multi layer perceptron (MLP), which is
|
||||
a collection of fully-connected layers each followed by an activation function.
|
||||
|
||||
:param input_dim: (int) Dimension of the input vector
|
||||
:param output_dim: (int) Dimension of the output vector; if not positive (e.g. -1), no output layer is added
|
||||
:param net_arch: ([int]) Architecture of the neural net
|
||||
:param net_arch: (List[int]) Architecture of the neural net
|
||||
It represents the number of units per layer.
|
||||
The length of this list is the number of layers.
|
||||
:param activation_fn: (th.nn.Module) The activation function
|
||||
:param activation_fn: (nn.Module) The activation function
|
||||
to use after each layer.
|
||||
:param squash_out: (bool) Whether to squash the output using a Tanh
|
||||
:param squash_output: (bool) Whether to squash the output using a Tanh
|
||||
activation function
|
||||
:return: (List[nn.Module])
|
||||
"""
|
||||
|
||||
if len(net_arch) > 0:
|
||||
|
|
@ -99,12 +113,14 @@ def create_mlp(input_dim, output_dim, net_arch,
|
|||
|
||||
if output_dim > 0:
|
||||
modules.append(nn.Linear(net_arch[-1], output_dim))
|
||||
if squash_out:
|
||||
if squash_output:
|
||||
modules.append(nn.Tanh())
|
||||
return modules
|
||||
|
||||
|
||||
def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
|
||||
def create_sde_feature_extractor(features_dim: int,
|
||||
sde_net_arch: List[int],
|
||||
activation_fn: nn.Module) -> Tuple[nn.Sequential, int]:
|
||||
"""
|
||||
Create the neural network that will be used to extract features
|
||||
for the SDE.
|
||||
|
|
@ -117,7 +133,7 @@ def create_sde_feature_extractor(features_dim, sde_net_arch, activation_fn):
|
|||
# Special case: when using states as features (i.e. sde_net_arch is an empty list)
|
||||
# don't use any activation function
|
||||
sde_activation = activation_fn if len(sde_net_arch) > 0 else None
|
||||
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_out=False)
|
||||
latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False)
|
||||
latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim
|
||||
sde_feature_extractor = nn.Sequential(*latent_sde_net)
|
||||
return sde_feature_extractor, latent_sde_dim
|
||||
|
|
@ -131,7 +147,7 @@ class BaseNetwork(nn.Module):
|
|||
def __init__(self):
|
||||
super(BaseNetwork, self).__init__()
|
||||
|
||||
def load_from_vector(self, vector):
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
Load parameters from a 1D vector.
|
||||
|
||||
|
|
@ -140,7 +156,7 @@ class BaseNetwork(nn.Module):
|
|||
device = next(self.parameters()).device
|
||||
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(device), self.parameters())
|
||||
|
||||
def parameters_to_vector(self):
|
||||
def parameters_to_vector(self) -> np.ndarray:
|
||||
"""
|
||||
Convert the parameters to a 1D vector.
|
||||
|
||||
|
|
@ -149,16 +165,16 @@ class BaseNetwork(nn.Module):
|
|||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
|
||||
_policy_registry = dict()
|
||||
_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]]
|
||||
|
||||
|
||||
def get_policy_from_name(base_policy_type, name):
|
||||
def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]:
|
||||
"""
|
||||
returns the registed policy from the base type and name
|
||||
Returns the registered policy from the base type and name
|
||||
|
||||
:param base_policy_type: (BasePolicy) the base policy object
|
||||
:param base_policy_type: (Type[BasePolicy]) the base policy class
|
||||
:param name: (str) the policy name
|
||||
:return: (base_policy_type) the policy
|
||||
:return: (Type[BasePolicy]) the policy
|
||||
"""
|
||||
if base_policy_type not in _policy_registry:
|
||||
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
|
|
@ -168,12 +184,13 @@ def get_policy_from_name(base_policy_type, name):
|
|||
return _policy_registry[base_policy_type][name]
|
||||
|
||||
|
||||
def register_policy(name, policy):
|
||||
def register_policy(name: str, policy: Type[BasePolicy]) -> None:
|
||||
"""
|
||||
returns the registed policy from the base type and name
|
||||
Register a policy, so it can be called using its name.
|
||||
e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...)
|
||||
|
||||
:param name: (str) the policy name
|
||||
:param policy: (subclass of BasePolicy) the policy
|
||||
:param policy: (Type[BasePolicy]) the policy class
|
||||
"""
|
||||
sub_class = None
|
||||
# For building the doc
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class PPOPolicy(BasePolicy):
|
|||
ortho_init=True, use_sde=False,
|
||||
log_std_init=0.0, full_std=True,
|
||||
sde_net_arch=None, use_expln=False, squash_output=False):
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device)
|
||||
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
|
||||
self.obs_dim = self.observation_space.shape[0]
|
||||
|
||||
# Default network architecture, from stable-baselines
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class SACPolicy(BasePolicy):
|
|||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False,
|
||||
log_std_init=-3, sde_net_arch=None, use_expln=False):
|
||||
super(SACPolicy, self).__init__(observation_space, action_space, device)
|
||||
super(SACPolicy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
if net_arch is None:
|
||||
net_arch = [256, 256]
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class Actor(BaseNetwork):
|
|||
self.sde_feature_extractor = None
|
||||
|
||||
if use_sde:
|
||||
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_out=False)
|
||||
latent_pi_net = create_mlp(obs_dim, -1, net_arch, activation_fn, squash_output=False)
|
||||
self.latent_pi = nn.Sequential(*latent_pi_net)
|
||||
latent_sde_dim = net_arch[-1]
|
||||
learn_features = sde_net_arch is not None
|
||||
|
|
@ -74,7 +74,7 @@ class Actor(BaseNetwork):
|
|||
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
|
||||
self.reset_noise()
|
||||
else:
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_out=True)
|
||||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
||||
self.mu = nn.Sequential(*actor_net)
|
||||
|
||||
def get_std(self) -> torch.Tensor:
|
||||
|
|
@ -134,7 +134,7 @@ class Actor(BaseNetwork):
|
|||
if self.clip_noise is not None:
|
||||
noise = th.clamp(noise, -self.clip_noise, self.clip_noise)
|
||||
# TODO: Replace with squashing -> need to account for that in the sde update
|
||||
# -> set squash_out=True in the action_dist?
|
||||
# -> set squash_output=True in the action_dist?
|
||||
# NOTE: the clipping is done in the rollout for now
|
||||
return self.mu(latent_pi) + noise
|
||||
# action, _ = self._get_action_dist_from_latent(latent_pi)
|
||||
|
|
@ -215,7 +215,7 @@ class TD3Policy(BasePolicy):
|
|||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
|
||||
clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device)
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
if net_arch is None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue