mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-03 03:59:13 +00:00
Review policies
This commit is contained in:
parent
cc7a58bc5f
commit
e9d8e05cc8
1 changed file with 35 additions and 29 deletions
|
|
@ -1,3 +1,7 @@
|
|||
"""Policies: abstract base class and concrete implementations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import collections
|
||||
from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
|
||||
from functools import partial
|
||||
|
||||
|
|
@ -17,7 +21,7 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis
|
|||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class BasePolicy(nn.Module):
|
||||
class BasePolicy(nn.Module, ABC):
|
||||
"""
|
||||
The base policy object
|
||||
|
||||
|
|
@ -98,13 +102,16 @@ class BasePolicy(nn.Module):
|
|||
module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(_progress_remaining: float) -> float:
|
||||
def _dummy_schedule(progress_remaining: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
del progress_remaining
|
||||
return 0.0
|
||||
|
||||
def forward(self, *_args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
@abstractmethod
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
|
@ -113,7 +120,6 @@ class BasePolicy(nn.Module):
|
|||
:param deterministic: (bool) Whether to use stochastic or deterministic actions
|
||||
:return: (th.Tensor) Taken action according to the policy
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self,
|
||||
observation: np.ndarray,
|
||||
|
|
@ -140,10 +146,8 @@ class BasePolicy(nn.Module):
|
|||
# Handle the different cases for images
|
||||
# as PyTorch use channel first format
|
||||
if is_image_space(self.observation_space):
|
||||
if (observation.shape == self.observation_space.shape
|
||||
if not (observation.shape == self.observation_space.shape
|
||||
or observation.shape[1:] == self.observation_space.shape):
|
||||
pass
|
||||
else:
|
||||
# Try to re-order the channels
|
||||
transpose_obs = VecTransposeImage.transpose_image(observation)
|
||||
if (transpose_obs.shape == self.observation_space.shape
|
||||
|
|
@ -160,21 +164,21 @@ class BasePolicy(nn.Module):
|
|||
# Convert to numpy
|
||||
actions = actions.cpu().numpy()
|
||||
|
||||
# Rescale to proper domain when using squashing
|
||||
if isinstance(self.action_space, gym.spaces.Box) and self.squash_output:
|
||||
actions = self.unscale_action(actions)
|
||||
|
||||
clipped_actions = actions
|
||||
# Clip the actions to avoid out of bound error when using gaussian distribution
|
||||
if isinstance(self.action_space, gym.spaces.Box) and not self.squash_output:
|
||||
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
if self.squash_output:
|
||||
# Rescale to proper domain when using squashing
|
||||
actions = self.unscale_action(actions)
|
||||
else:
|
||||
# Actions could be on arbitrary scale, so clip the actions to avoid
|
||||
# out of bound error (e.g. if sampling from a Gaussian distribution)
|
||||
actions = np.clip(actions, self.action_space.low, self.action_space.high)
|
||||
|
||||
if not vectorized_env:
|
||||
if state is not None:
|
||||
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
|
||||
clipped_actions = clipped_actions[0]
|
||||
actions = actions[0]
|
||||
|
||||
return clipped_actions, state
|
||||
return actions, state
|
||||
|
||||
def scale_action(self, action: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
|
|
@ -227,7 +231,7 @@ class BasePolicy(nn.Module):
|
|||
Load policy from path.
|
||||
|
||||
:param path: (str)
|
||||
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:return: (BasePolicy)
|
||||
"""
|
||||
device = get_device(device)
|
||||
|
|
@ -294,7 +298,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
lr_schedule: Callable[[float], float],
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'auto',
|
||||
activation_fn: Type[nn.Module] = nn.Tanh,
|
||||
|
|
@ -313,7 +317,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
# Small values to avoid NaN in ADAM optimizer
|
||||
# Small values to avoid NaN in Adam optimizer
|
||||
if optimizer_class == th.optim.Adam:
|
||||
optimizer_kwargs['eps'] = 1e-5
|
||||
|
||||
|
|
@ -366,15 +370,17 @@ class ActorCriticPolicy(BasePolicy):
|
|||
def _get_data(self) -> Dict[str, Any]:
|
||||
data = super()._get_data()
|
||||
|
||||
default_none_kwargs = self.dist_kwargs or collections.defaultdict()
|
||||
|
||||
data.update(dict(
|
||||
net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn,
|
||||
use_sde=self.use_sde,
|
||||
log_std_init=self.log_std_init,
|
||||
squash_output=self.dist_kwargs['squash_output'] if self.dist_kwargs else None,
|
||||
full_std=self.dist_kwargs['full_std'] if self.dist_kwargs else None,
|
||||
sde_net_arch=self.dist_kwargs['sde_net_arch'] if self.dist_kwargs else None,
|
||||
use_expln=self.dist_kwargs['use_expln'] if self.dist_kwargs else None,
|
||||
squash_output=default_none_kwargs['squash_output'],
|
||||
full_std=default_none_kwargs['full_std'],
|
||||
sde_net_arch=default_none_kwargs['sde_net_arch'],
|
||||
use_expln=default_none_kwargs['use_expln'],
|
||||
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
|
||||
ortho_init=self.ortho_init,
|
||||
optimizer_class=self.optimizer_class,
|
||||
|
|
@ -394,7 +400,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||
StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
def _build(self, lr_schedule: Callable[[float], float]) -> None:
|
||||
"""
|
||||
Create the networks and the optimizer.
|
||||
|
||||
|
|
@ -651,10 +657,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
|
|||
:return: (Type[BasePolicy]) the policy
|
||||
"""
|
||||
if base_policy_type not in _policy_registry:
|
||||
raise ValueError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
|
||||
if name not in _policy_registry[base_policy_type]:
|
||||
raise ValueError(f"Error: unknown policy type {name},"
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
raise KeyError(f"Error: unknown policy type {name},"
|
||||
f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
|
||||
return _policy_registry[base_policy_type][name]
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue