Complete save/load for TD3 policy

This commit is contained in:
Antonin RAFFIN 2020-04-20 15:43:23 +02:00
parent 17f9246257
commit a4df08cd28
2 changed files with 75 additions and 10 deletions

View file

@ -1,4 +1,4 @@
from typing import Union, Type, Dict, List, Tuple, Optional
from typing import Union, Type, Dict, List, Tuple, Optional, Any
from itertools import zip_longest
@ -57,11 +57,19 @@ class BasePolicy(nn.Module):
return self._squash_output
@staticmethod
def init_weights(module: nn.Module, gain: float = 1) -> None:
    """
    Orthogonal initialization (used in PPO and A2C).

    Only ``nn.Linear`` modules are modified; any other module is left untouched.

    :param module: (nn.Module) Module to initialize in place
    :param gain: (float) Scaling factor forwarded to ``nn.init.orthogonal_``
    """
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=gain)
        # A layer created with ``bias=False`` has ``bias`` set to None,
        # so there is nothing to zero out in that case.
        if module.bias is not None:
            module.bias.data.fill_(0.0)
@staticmethod
def _dummy_schedule(progress: float) -> float:
""" (float) Useful for pickling policy."""
return 0.0
def forward(self, *_args, **kwargs):
    """
    Not implemented at this level; calling it always raises.

    :raises NotImplementedError: on every call
    """
    raise NotImplementedError()
@ -191,24 +199,37 @@ class BasePolicy(nn.Module):
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
def _get_data(self) -> Dict[str, Any]:
return dict(
observation_space=self.observation_space,
action_space=self.action_space,
# Passed to the constructor by child classes
# squash_output=self.squash_output,
# features_extractor=self.features_extractor
normalize_images=self.normalize_images,
)
def save(self, path: str) -> None:
    """
    Save policy to a given location.

    The file contains a dictionary with two entries: ``'state_dict'``
    (the result of ``state_dict()``) and ``'data'`` (the result of
    ``_get_data()``).

    :param path: (str)
    """
    th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
@classmethod
def load(cls, path: str) -> 'BasePolicy':
    """
    Load policy from path.

    A new instance is built from the saved ``'data'`` entry, then the saved
    ``'state_dict'`` is loaded into it and the instance is moved to the
    device returned by ``get_device()``.

    :param path: (str)
    :return: (BasePolicy) the newly created policy
    """
    device = get_device()
    # NOTE(security): ``th.load`` unpickles the file content, only load
    # files coming from a trusted source.
    saved_variables = th.load(path, map_location=device)
    # Create the policy object from the saved constructor arguments
    model = cls(**saved_variables['data'])
    # Restore the weights
    model.load_state_dict(saved_variables['state_dict'])
    model.to(device)
    return model
def load_from_vector(self, vector: np.ndarray):
"""

View file

@ -66,6 +66,15 @@ class Actor(BasePolicy):
self.sde_features_extractor = None
self.features_extractor = features_extractor
self.normalize_images = normalize_images
self.net_arch = net_arch
self.features_dim = features_dim
self.activation_fn = activation_fn
self.clip_noise = clip_noise
self.lr_sde = lr_sde
self.log_std_init = log_std_init
self.sde_net_arch = sde_net_arch
self.use_expln = use_expln
self.full_std = full_std
action_dim = get_action_dim(self.action_space)
@ -89,13 +98,30 @@ class Actor(BasePolicy):
log_std_init=log_std_init)
# Squash output
self.mu = nn.Sequential(action_net, nn.Tanh())
self.clip_noise = clip_noise
self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde)
self.reset_noise()
else:
actor_net = create_mlp(features_dim, action_dim, net_arch, activation_fn, squash_output=True)
self.mu = nn.Sequential(*actor_net)
def _get_data(self) -> Dict[str, Any]:
    """
    Extend the parent's data with the attributes stored on this actor.

    :return: (Dict[str, Any]) the parent's entries plus the network
        settings, the SDE settings and the features extractor
    """
    data = super()._get_data()
    data.update(
        net_arch=self.net_arch,
        features_dim=self.features_dim,
        activation_fn=self.activation_fn,
        use_sde=self.use_sde,
        log_std_init=self.log_std_init,
        clip_noise=self.clip_noise,
        lr_sde=self.lr_sde,
        full_std=self.full_std,
        sde_net_arch=self.sde_net_arch,
        use_expln=self.use_expln,
        features_extractor=self.features_extractor,
    )
    return data
def get_std(self) -> th.Tensor:
"""
Retrieve the standard deviation of the action distribution.
@ -344,6 +370,24 @@ class TD3Policy(BasePolicy):
features_dim=self.features_dim)
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error
def _get_data(self) -> Dict[str, Any]:
    """
    Extend the parent's data with the settings kept in ``net_args``,
    ``actor_kwargs`` and the optimizer attributes.

    :return: (Dict[str, Any]) the parent's entries plus network, SDE,
        learning rate schedule and optimizer settings
    """
    data = super()._get_data()
    data.update(
        net_arch=self.net_args['net_arch'],
        activation_fn=self.net_args['activation_fn'],
        use_sde=self.actor_kwargs['use_sde'],
        log_std_init=self.actor_kwargs['log_std_init'],
        clip_noise=self.actor_kwargs['clip_noise'],
        lr_sde=self.actor_kwargs['lr_sde'],
        sde_net_arch=self.actor_kwargs['sde_net_arch'],
        use_expln=self.actor_kwargs['use_expln'],
        # dummy lr schedule, not needed for loading policy alone
        lr_schedule=self._dummy_schedule,
        optimizer=self.optimizer_class,
        optimizer_kwargs=self.optimizer_kwargs,
    )
    return data
def reset_noise(self) -> None:
    """
    Delegate to the actor's ``reset_noise`` and return its result.
    """
    return self.actor.reset_noise()