Review policies

This commit is contained in:
Adam Gleave 2020-07-02 21:04:36 -07:00
parent cc7a58bc5f
commit e9d8e05cc8

View file

@ -1,3 +1,7 @@
"""Policies: abstract base class and concrete implementations."""
from abc import ABC, abstractmethod
import collections
from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
from functools import partial
@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis
StateDependentNoiseDistribution)
class BasePolicy(nn.Module):
class BasePolicy(nn.Module, ABC):
"""
The base policy object
@ -98,13 +102,16 @@ class BasePolicy(nn.Module):
module.bias.data.fill_(0.0)
@staticmethod
def _dummy_schedule(_progress_remaining: float) -> float:
def _dummy_schedule(progress_remaining: float) -> float:
""" (float) Useful for pickling policy."""
del progress_remaining
return 0.0
def forward(self, *_args, **kwargs):
raise NotImplementedError()
@abstractmethod
def forward(self, *args, **kwargs):
del args, kwargs
@abstractmethod
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
@ -113,7 +120,6 @@ class BasePolicy(nn.Module):
:param deterministic: (bool) Whether to use stochastic or deterministic actions
:return: (th.Tensor) Taken action according to the policy
"""
raise NotImplementedError()
def predict(self,
observation: np.ndarray,
@ -140,10 +146,8 @@ class BasePolicy(nn.Module):
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space):
if (observation.shape == self.observation_space.shape
if not (observation.shape == self.observation_space.shape
or observation.shape[1:] == self.observation_space.shape):
pass
else:
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if (transpose_obs.shape == self.observation_space.shape
@ -160,21 +164,21 @@ class BasePolicy(nn.Module):
# Convert to numpy
actions = actions.cpu().numpy()
# Rescale to proper domain when using squashing
if isinstance(self.action_space, gym.spaces.Box) and self.squash_output:
actions = self.unscale_action(actions)
clipped_actions = actions
# Clip the actions to avoid out of bound error when using gaussian distribution
if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output:
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if isinstance(self.action_space, gym.spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
clipped_actions = clipped_actions[0]
actions = actions[0]
return clipped_actions, state
return actions, state
def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
@ -227,7 +231,7 @@ class BasePolicy(nn.Module):
Load policy from path.
:param path: (str)
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
:return: (BasePolicy)
"""
device = get_device(device)
@ -294,7 +298,7 @@ class ActorCriticPolicy(BasePolicy):
def __init__(self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Callable,
lr_schedule: Callable[[float], float],
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
device: Union[th.device, str] = 'auto',
activation_fn: Type[nn.Module] = nn.Tanh,
@ -313,7 +317,7 @@ class ActorCriticPolicy(BasePolicy):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in ADAM optimizer
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
optimizer_kwargs['eps'] = 1e-5
@ -366,15 +370,17 @@ class ActorCriticPolicy(BasePolicy):
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
default_none_kwargs = self.dist_kwargs or collections.defaultdict()
data.update(dict(
net_arch=self.net_arch,
activation_fn=self.activation_fn,
use_sde=self.use_sde,
log_std_init=self.log_std_init,
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
squash_output=default_none_kwargs['squash_output'],
full_std=default_none_kwargs['full_std'],
sde_net_arch=default_none_kwargs['sde_net_arch'],
use_expln=default_none_kwargs['use_expln'],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
ortho_init=self.ortho_init,
optimizer_class=self.optimizer_class,
@ -394,7 +400,7 @@ class ActorCriticPolicy(BasePolicy):
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, lr_schedule: Callable) -> None:
def _build(self, lr_schedule: Callable[[float], float]) -> None:
"""
Create the networks and the optimizer.
@ -651,10 +657,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
:return: (Type[BasePolicy]) the policy
"""
if base_policy_type not in _policy_registry:
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
raise ValueError(f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
raise KeyError(f"Error: unknown policy type {name},"
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
return _policy_registry[base_policy_type][name]