mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-21 22:00:21 +00:00
Refactor BasePolicy by introducing a new BaseModel ABC for Critic classes to inherit from.
This commit is contained in:
parent
91bbc28c0f
commit
0345591dea
3 changed files with 96 additions and 81 deletions
|
|
@ -21,9 +21,12 @@ from stable_baselines3.common.distributions import (make_proba_distribution, Dis
|
|||
StateDependentNoiseDistribution)
|
||||
|
||||
|
||||
class BasePolicy(nn.Module, ABC):
|
||||
class BaseModel(nn.Module, ABC):
|
||||
"""
|
||||
The base policy object
|
||||
The base model object: makes predictions in response to observations.
|
||||
|
||||
In the case of policies, the prediction is an action. In the case of critics, it is the
|
||||
estimated value of the observation.
|
||||
|
||||
:param observation_space: (gym.spaces.Space) The observation space of the environment
|
||||
:param action_space: (gym.spaces.Space) The action space of the environment
|
||||
|
|
@ -39,8 +42,6 @@ class BasePolicy(nn.Module, ABC):
|
|||
``th.optim.Adam`` by default
|
||||
:param optimizer_kwargs: (Optional[Dict[str, Any]]) Additional keyword arguments,
|
||||
excluding the learning rate, to pass to the optimizer
|
||||
:param squash_output: (bool) For continuous actions, whether the output is squashed
|
||||
or not using a ``tanh()`` function.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
|
@ -52,9 +53,8 @@ class BasePolicy(nn.Module, ABC):
|
|||
features_extractor: Optional[nn.Module] = None,
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
squash_output: bool = False):
|
||||
super(BasePolicy, self).__init__()
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None):
|
||||
super(BaseModel, self).__init__()
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
optimizer_kwargs = {}
|
||||
|
|
@ -67,7 +67,6 @@ class BasePolicy(nn.Module, ABC):
|
|||
self.device = get_device(device)
|
||||
self.features_extractor = features_extractor
|
||||
self.normalize_images = normalize_images
|
||||
self._squash_output = squash_output
|
||||
|
||||
self.optimizer_class = optimizer_class
|
||||
self.optimizer_kwargs = optimizer_kwargs
|
||||
|
|
@ -76,6 +75,10 @@ class BasePolicy(nn.Module, ABC):
|
|||
self.features_extractor_class = features_extractor_class
|
||||
self.features_extractor_kwargs = features_extractor_kwargs
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
def extract_features(self, obs: th.Tensor) -> th.Tensor:
|
||||
"""
|
||||
Preprocess the observation if needed and extract features.
|
||||
|
|
@ -87,9 +90,89 @@ class BasePolicy(nn.Module, ABC):
|
|||
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
|
||||
return self.features_extractor(preprocessed_obs)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get data that need to be saved in order to re-create the model.
|
||||
This corresponds to the arguments of the constructor.
|
||||
|
||||
:return: (Dict[str, Any])
|
||||
"""
|
||||
return dict(
|
||||
observation_space=self.observation_space,
|
||||
action_space=self.action_space,
|
||||
# Passed to the constructor by child class
|
||||
# squash_output=self.squash_output,
|
||||
# features_extractor=self.features_extractor
|
||||
normalize_images=self.normalize_images,
|
||||
)
|
||||
|
||||
def save(self, path: str) -> None:
|
||||
"""
|
||||
Save model to a given location.
|
||||
|
||||
:param path: (str)
|
||||
"""
|
||||
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel':
|
||||
"""
|
||||
Load model from path.
|
||||
|
||||
:param path: (str)
|
||||
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:return: (BasePolicy)
|
||||
"""
|
||||
device = get_device(device)
|
||||
saved_variables = th.load(path, map_location=device)
|
||||
# Create policy object
|
||||
model = cls(**saved_variables['data'])
|
||||
# Load weights
|
||||
model.load_state_dict(saved_variables['state_dict'])
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
Load parameters from a 1D vector.
|
||||
|
||||
:param vector: (np.ndarray)
|
||||
"""
|
||||
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
|
||||
|
||||
def parameters_to_vector(self) -> np.ndarray:
|
||||
"""
|
||||
Convert the parameters to a 1D vector.
|
||||
|
||||
:return: (np.ndarray)
|
||||
"""
|
||||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
|
||||
class BasePolicy(BaseModel):
|
||||
"""The base policy object.
|
||||
|
||||
Parameters are mostly the same as `BaseModel`; additions are documented below.
|
||||
|
||||
:param args: positional arguments passed through to `BaseModel`.
|
||||
:param kwargs: keyword arguments passed through to `BaseModel`.
|
||||
:param squash_output: (bool) For continuous actions, whether the output is squashed
|
||||
or not using a ``tanh()`` function.
|
||||
"""
|
||||
def __init__(self, *args, squash_output: bool = False, **kwargs):
|
||||
super(BasePolicy, self).__init__(*args, **kwargs)
|
||||
self._squash_output = squash_output
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(progress_remaining: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
del progress_remaining
|
||||
return 0.0
|
||||
|
||||
@property
|
||||
def squash_output(self) -> bool:
|
||||
""" (bool) Getter for squash_output."""
|
||||
"""(bool) Getter for squash_output."""
|
||||
return self._squash_output
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -101,16 +184,7 @@ class BasePolicy(nn.Module, ABC):
|
|||
nn.init.orthogonal_(module.weight, gain=gain)
|
||||
module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(progress_remaining: float) -> float:
|
||||
""" (float) Useful for pickling policy."""
|
||||
del progress_remaining
|
||||
return 0.0
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
del args, kwargs
|
||||
|
||||
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
"""
|
||||
Get the action according to the policy for a given observation.
|
||||
|
|
@ -122,7 +196,6 @@ class BasePolicy(nn.Module, ABC):
|
|||
:param deterministic: (bool) Whether to use stochastic or deterministic actions
|
||||
:return: (th.Tensor) Taken action according to the policy
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def predict(self,
|
||||
observation: np.ndarray,
|
||||
|
|
@ -205,64 +278,6 @@ class BasePolicy(nn.Module, ABC):
|
|||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get data that need to be saved in order to re-create the policy.
|
||||
This corresponds to the arguments of the constructor.
|
||||
|
||||
:return: (Dict[str, Any])
|
||||
"""
|
||||
return dict(
|
||||
observation_space=self.observation_space,
|
||||
action_space=self.action_space,
|
||||
# Passed to the constructor by child class
|
||||
# squash_output=self.squash_output,
|
||||
# features_extractor=self.features_extractor
|
||||
normalize_images=self.normalize_images,
|
||||
)
|
||||
|
||||
def save(self, path: str) -> None:
|
||||
"""
|
||||
Save policy to a given location.
|
||||
|
||||
:param path: (str)
|
||||
"""
|
||||
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy':
|
||||
"""
|
||||
Load policy from path.
|
||||
|
||||
:param path: (str)
|
||||
:param device: (Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:return: (BasePolicy)
|
||||
"""
|
||||
device = get_device(device)
|
||||
saved_variables = th.load(path, map_location=device)
|
||||
# Create policy object
|
||||
model = cls(**saved_variables['data'])
|
||||
# Load weights
|
||||
model.load_state_dict(saved_variables['state_dict'])
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
Load parameters from a 1D vector.
|
||||
|
||||
:param vector: (np.ndarray)
|
||||
"""
|
||||
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
|
||||
|
||||
def parameters_to_vector(self) -> np.ndarray:
|
||||
"""
|
||||
Convert the parameters to a 1D vector.
|
||||
|
||||
:return: (np.ndarray)
|
||||
"""
|
||||
return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy()
|
||||
|
||||
|
||||
class ActorCriticPolicy(BasePolicy):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy, create_sde_features_extractor
|
||||
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
|
||||
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
|
|
@ -179,7 +179,7 @@ class Actor(BasePolicy):
|
|||
return self.forward(observation, deterministic)
|
||||
|
||||
|
||||
class Critic(BasePolicy):
|
||||
class Critic(BaseModel):
|
||||
"""
|
||||
Critic network (q-value function) for SAC.
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch as th
|
|||
import torch.nn as nn
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
from stable_baselines3.common.policies import BasePolicy, register_policy
|
||||
from stable_baselines3.common.policies import BaseModel, BasePolicy, register_policy
|
||||
from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
|
||||
|
||||
|
||||
|
|
@ -71,7 +71,7 @@ class Actor(BasePolicy):
|
|||
return self.forward(observation, deterministic=deterministic)
|
||||
|
||||
|
||||
class Critic(BasePolicy):
|
||||
class Critic(BaseModel):
|
||||
"""
|
||||
Critic network for TD3,
|
||||
in fact it represents the action-state value function (Q-value function)
|
||||
|
|
|
|||
Loading…
Reference in a new issue