mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-28 03:21:16 +00:00
Complete save/load for TD3 policy
This commit is contained in:
parent
17f9246257
commit
a4df08cd28
2 changed files with 75 additions and 10 deletions
|
|
@@ -1,4 +1,4 @@
|
|||
from typing import Union, Type, Dict, List, Tuple, Optional
|
||||
from typing import Union, Type, Dict, List, Tuple, Optional, Any
|
||||
|
||||
from itertools import zip_longest
|
||||
|
||||
|
|
@@ -57,11 +57,19 @@ class BasePolicy(nn.Module):
|
|||
return self._squash_output
|
||||
|
||||
@staticmethod
|
||||
def init_weights(module: nn.Module, gain: float = 1) -> None:
    """
    Orthogonal initialization (used in PPO and A2C)

    Only ``nn.Linear`` modules are modified; any other module is left untouched.

    :param module: (nn.Module) Module to initialize
    :param gain: (float) Scaling factor forwarded to ``nn.init.orthogonal_``
    """
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=gain)
        # A linear layer built with ``bias=False`` has ``bias`` set to None,
        # so there is nothing to zero out in that case
        if module.bias is not None:
            module.bias.data.fill_(0.0)
|
||||
|
||||
@staticmethod
|
||||
def _dummy_schedule(progress: float) -> float:
    """
    Constant placeholder schedule, useful for pickling policy.

    :param progress: (float) Ignored
    :return: (float) Always 0.0
    """
    return 0.0
|
||||
|
||||
def forward(self, *_args, **kwargs):
    """
    Placeholder forward pass: this implementation always raises.

    :raises NotImplementedError: on every call
    """
    raise NotImplementedError()
|
||||
|
||||
|
|
@@ -191,24 +199,37 @@ class BasePolicy(nn.Module):
|
|||
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
|
||||
.format(observation_space))
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
    """
    Collect the attributes that ``save()`` stores next to the weights and that
    ``load()`` passes back to the constructor as keyword arguments.

    :return: (Dict[str, Any]) Observation space, action space and the
        ``normalize_images`` flag
    """
    # ``squash_output`` and ``features_extractor`` are left out on purpose:
    # they are passed to the constructor by child classes
    return {
        'observation_space': self.observation_space,
        'action_space': self.action_space,
        'normalize_images': self.normalize_images,
    }
|
||||
|
||||
def save(self, path: str) -> None:
    """
    Save policy to a given location.

    The file holds a dictionary with the network weights under
    ``'state_dict'`` and the result of ``_get_data()`` under ``'data'``.

    :param path: (str)
    """
    checkpoint = {'state_dict': self.state_dict(), 'data': self._get_data()}
    th.save(checkpoint, path)
|
||||
|
||||
@classmethod
def load(cls, path: str) -> 'BasePolicy':
    """
    Load policy from path.

    The file is expected to hold a dictionary with the constructor
    arguments under ``'data'`` and the network weights under ``'state_dict'``.

    :param path: (str)
    :return: (BasePolicy) New instance of ``cls`` with the stored weights,
        moved to the device returned by ``get_device()``
    :raises KeyError: if one of the two entries is missing, e.g. when the
        file only contains a state dict
    """
    device = get_device()
    # NOTE(security): ``th.load`` unpickles the file content, which can run
    # arbitrary code. Only load files coming from a trusted source.
    saved_variables = th.load(path, map_location=device)
    missing = [key for key in ('data', 'state_dict') if key not in saved_variables]
    if missing:
        # Same exception type as a plain lookup, but with an explanation
        raise KeyError("Error: '{}' is not a complete saved policy, missing entries: {}."
                       .format(path, missing))
    model = cls(**saved_variables['data'])
    model.load_state_dict(saved_variables['state_dict'])
    model.to(device)
    return model
|
||||
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
|
|
|
|||
|
|
@@ -66,6 +66,15 @@ class Actor(BasePolicy):
|
|||
self.sde_features_extractor = None
|
||||
self.features_extractor = features_extractor
|
||||
self.normalize_images = normalize_images
|
||||
self.net_arch = net_arch
|
||||
self.features_dim = features_dim
|
||||
self.activation_fn = activation_fn
|
||||
self.clip_noise = clip_noise
|
||||
self.lr_sde = lr_sde
|
||||
self.log_std_init = log_std_init
|
||||
self.sde_net_arch = sde_net_arch
|
||||
self.use_expln = use_expln
|
||||
self.full_std = full_std
|
||||
|
||||
action_dim = get_action_dim(self.action_space)
|
||||
|
||||
|
|
@@ -89,13 +98,30 @@ class Actor(BasePolicy):
|
|||
log_std_init=log_std_init)
|
||||
# Squash output
|
||||
self.mu = nn.Sequential(action_net, nn.Tanh())
|
||||
self.clip_noise = clip_noise
|
||||
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
|
||||
self.reset_noise()
|
||||
else:
|
||||
actor_net = create_mlp(features_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
||||
self.mu = nn.Sequential(*actor_net)
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
    """
    Extend the parent's data with the actor-specific attributes.

    :return: (Dict[str, Any]) Parent entries plus one entry per attribute
        listed below, each stored under the attribute's own name
    """
    data = super()._get_data()

    actor_fields = (
        'net_arch',
        'features_dim',
        'activation_fn',
        'use_sde',
        'log_std_init',
        'clip_noise',
        'lr_sde',
        'full_std',
        'sde_net_arch',
        'use_expln',
        'features_extractor',
    )
    for name in actor_fields:
        data[name] = getattr(self, name)
    return data
|
||||
|
||||
def get_std(self) -> th.Tensor:
|
||||
"""
|
||||
Retrieve the standard deviation of the action distribution.
|
||||
|
|
@@ -344,6 +370,24 @@ class TD3Policy(BasePolicy):
|
|||
features_dim=self.features_dim)
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
    """
    Extend the parent's data with the TD3 policy settings.

    Network settings are read from ``self.net_args``, actor settings from
    ``self.actor_kwargs``; the optimizer class and its kwargs are stored too.

    :return: (Dict[str, Any]) Parent entries plus the TD3-specific ones
    """
    data = super()._get_data()

    net_args = self.net_args
    data['net_arch'] = net_args['net_arch']
    data['activation_fn'] = net_args['activation_fn']

    actor_kwargs = self.actor_kwargs
    for key in ('use_sde', 'log_std_init', 'clip_noise', 'lr_sde', 'sde_net_arch', 'use_expln'):
        data[key] = actor_kwargs[key]

    # dummy lr schedule, not needed for loading policy alone
    data['lr_schedule'] = self._dummy_schedule
    data['optimizer'] = self.optimizer_class
    data['optimizer_kwargs'] = self.optimizer_kwargs
    return data
|
||||
|
||||
def reset_noise(self) -> None:
    """
    Delegate to the actor's ``reset_noise()`` and hand back its result.
    """
    return self.actor.reset_noise()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue