From a4df08cd287e8c57a214f988880cd2016f86f835 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 20 Apr 2020 15:43:23 +0200 Subject: [PATCH] Complete save/load for TD3 policy --- torchy_baselines/common/policies.py | 39 ++++++++++++++++++------ torchy_baselines/td3/policies.py | 46 ++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 98ac717..982a400 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -1,4 +1,4 @@ -from typing import Union, Type, Dict, List, Tuple, Optional +from typing import Union, Type, Dict, List, Tuple, Optional, Any from itertools import zip_longest @@ -57,11 +57,19 @@ class BasePolicy(nn.Module): return self._squash_output @staticmethod - def init_weights(module: nn.Module, gain: float = 1): + def init_weights(module: nn.Module, gain: float = 1) -> None: + """ + Orthogonal initialization (used in PPO and A2C) + """ if isinstance(module, nn.Linear): nn.init.orthogonal_(module.weight, gain=gain) module.bias.data.fill_(0.0) + @staticmethod + def _dummy_schedule(progress: float) -> float: + """ (float) Useful for pickling policy.""" + return 0.0 + def forward(self, *_args, **kwargs): raise NotImplementedError() @@ -191,24 +199,37 @@ class BasePolicy(nn.Module): raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}." .format(observation_space)) + def _get_data(self) -> Dict[str, Any]: + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child classes + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) def save(self, path: str) -> None: """ - Save policy weights to a given location. - NOTE: we don't save policy parameters + Save policy to a given location. :param path: (str) """ - th.save(self.state_dict(), path) + th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) - def load(self, path: str) -> None: + @classmethod + def load(cls, path: str) -> 'BasePolicy': """ - Load policy weights from path. - NOTE: we don't load policy parameters + Load policy from path. :param path: (str) """ - self.load_state_dict(th.load(path)) + device = get_device() + saved_variables = th.load(path, map_location=device) + model = cls(**saved_variables['data']) + model.load_state_dict(saved_variables['state_dict']) + model.to(device) + return model def load_from_vector(self, vector: np.ndarray): """ diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 8a8459c..c1ecbc8 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -66,6 +66,15 @@ class Actor(BasePolicy): self.sde_features_extractor = None self.features_extractor = features_extractor self.normalize_images = normalize_images + self.net_arch = net_arch + self.features_dim = features_dim + self.activation_fn = activation_fn + self.clip_noise = clip_noise + self.lr_sde = lr_sde + self.log_std_init = log_std_init + self.sde_net_arch = sde_net_arch + self.use_expln = use_expln + self.full_std = full_std action_dim = get_action_dim(self.action_space) @@ -89,13 +98,30 @@ class Actor(BasePolicy): log_std_init=log_std_init) # Squash output self.mu = nn.Sequential(action_net, nn.Tanh()) - self.clip_noise = clip_noise self.sde_optimizer = th.optim.Adam([self.log_std], lr=lr_sde) self.reset_noise() else: actor_net = create_mlp(features_dim, action_dim, net_arch, activation_fn, squash_output=True) self.mu = nn.Sequential(*actor_net) + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update(dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + clip_noise=self.clip_noise, + lr_sde=self.lr_sde, + full_std=self.full_std, + sde_net_arch=self.sde_net_arch, + use_expln=self.use_expln, + features_extractor=self.features_extractor + )) + return data + def get_std(self) -> th.Tensor: """ Retrieve the standard deviation of the action distribution. @@ -344,6 +370,24 @@ class TD3Policy(BasePolicy): features_dim=self.features_dim) self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error + def _get_data(self) -> Dict[str, Any]: + data = super()._get_data() + + data.update(dict( + net_arch=self.net_args['net_arch'], + activation_fn=self.net_args['activation_fn'], + use_sde=self.actor_kwargs['use_sde'], + log_std_init=self.actor_kwargs['log_std_init'], + clip_noise=self.actor_kwargs['clip_noise'], + lr_sde=self.actor_kwargs['lr_sde'], + sde_net_arch=self.actor_kwargs['sde_net_arch'], + use_expln=self.actor_kwargs['use_expln'], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs + )) + return data + def reset_noise(self) -> None: return self.actor.reset_noise()