From e9d8e05cc8e2ac09c2f011ef37589bd405f817d9 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Thu, 2 Jul 2020 21:04:36 -0700 Subject: [PATCH] Review policies --- stable_baselines3/common/policies.py | 64 +++++++++++++++------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 6d4c87f..48e066c 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -1,3 +1,7 @@ +"""Policies: abstract base class and concrete implementations.""" + +from abc import ABC, abstractmethod +import collections from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable from functools import partial @@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis StateDependentNoiseDistribution) -class BasePolicy(nn.Module): +class BasePolicy(nn.Module, ABC): """ The base policy object @@ -98,13 +102,16 @@ class BasePolicy(nn.Module): module.bias.data.fill_(0.0) @staticmethod - def _dummy_schedule(_progress_remaining: float) -> float: + def _dummy_schedule(progress_remaining: float) -> float: """ (float) Useful for pickling policy.""" + del progress_remaining return 0.0 - def forward(self, *_args, **kwargs): - raise NotImplementedError() + @abstractmethod + def forward(self, *args, **kwargs): + del args, kwargs + @abstractmethod def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ Get the action according to the policy for a given observation. @@ -113,7 +120,6 @@ class BasePolicy(nn.Module): :param deterministic: (bool) Whether to use stochastic or deterministic actions :return: (th.Tensor) Taken action according to the policy """ - raise NotImplementedError() def predict(self, observation: np.ndarray, @@ -140,10 +146,8 @@ class BasePolicy(nn.Module): # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if (observation.shape == self.observation_space.shape + if not (observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape): - pass - else: # Try to re-order the channels transpose_obs = VecTransposeImage.transpose_image(observation) if (transpose_obs.shape == self.observation_space.shape @@ -160,21 +164,21 @@ class BasePolicy(nn.Module): # Convert to numpy actions = actions.cpu().numpy() - # Rescale to proper domain when using squashing - if isinstance(self.action_space, gym.spaces.Box) and self.squash_output: - actions = self.unscale_action(actions) - - clipped_actions = actions - # Clip the actions to avoid out of bound error when using gaussian distribution - if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output: - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) if not vectorized_env: if state is not None: raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - clipped_actions = clipped_actions[0] + actions = actions[0] - return clipped_actions, state + return actions, state def scale_action(self, action: np.ndarray) -> np.ndarray: """ @@ -227,7 +231,7 @@ class BasePolicy(nn.Module): Load policy from path. :param path: (str) - :param device: ( Union[th.device, str]) Device on which the policy should be loaded. + :param device: (Union[th.device, str]) Device on which the policy should be loaded. :return: (BasePolicy) """ device = get_device(device) @@ -294,7 +298,7 @@ class ActorCriticPolicy(BasePolicy): def __init__(self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, - lr_schedule: Callable, + lr_schedule: Callable[[float], float], net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, device: Union[th.device, str] = 'auto', activation_fn: Type[nn.Module] = nn.Tanh, @@ -313,7 +317,7 @@ class ActorCriticPolicy(BasePolicy): if optimizer_kwargs is None: optimizer_kwargs = {} - # Small values to avoid NaN in ADAM optimizer + # Small values to avoid NaN in Adam optimizer if optimizer_class == th.optim.Adam: optimizer_kwargs['eps'] = 1e-5 @@ -366,15 +370,17 @@ class ActorCriticPolicy(BasePolicy): def _get_data(self) -> Dict[str, Any]: data = super()._get_data() + default_none_kwargs = self.dist_kwargs or collections.defaultdict() + data.update(dict( net_arch=self.net_arch, activation_fn=self.activation_fn, use_sde=self.use_sde, log_std_init=self.log_std_init, - squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None, - full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None, - sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None, - use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None, + squash_output=default_none_kwargs['squash_output'], + full_std=default_none_kwargs['full_std'], + sde_net_arch=default_none_kwargs['sde_net_arch'], + use_expln=default_none_kwargs['use_expln'], lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone ortho_init=self.ortho_init, optimizer_class=self.optimizer_class, @@ -394,7 +400,7 @@ class ActorCriticPolicy(BasePolicy): StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE' self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - def _build(self, lr_schedule: Callable) -> None: + def _build(self, lr_schedule: Callable[[float], float]) -> None: """ Create the networks and the optimizer. @@ -651,10 +657,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[ :return: (Type[BasePolicy]) the policy """ if base_policy_type not in _policy_registry: - raise ValueError(f"Error: the policy type {base_policy_type} is not registered!") + raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") if name not in _policy_registry[base_policy_type]: - raise ValueError(f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") + raise KeyError(f"Error: unknown policy type {name}," + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!") return _policy_registry[base_policy_type][name]