From 6bd2d87f3311968d92a8ec3c68be8b7353a4c055 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 20 Apr 2020 16:21:47 +0200 Subject: [PATCH] Improve doc --- torchy_baselines/common/policies.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 41b7211..9aff338 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -200,6 +200,12 @@ class BasePolicy(nn.Module): .format(observation_space)) def _get_data(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the policy. + This corresponds to the arguments of the constructor. + + :return: (Dict[str, Any]) + """ return dict( observation_space=self.observation_space, action_space=self.action_space, @@ -218,15 +224,19 @@ class BasePolicy(nn.Module): th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) @classmethod - def load(cls, path: str) -> 'BasePolicy': + def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': """ Load policy from path. :param path: (str) + :param device: ( Union[th.device, str]) Device on which the policy should be loaded. + :return: (BasePolicy) """ - device = get_device() + device = get_device(device) saved_variables = th.load(path, map_location=device) + # Create policy object model = cls(**saved_variables['data']) + # Load weights model.load_state_dict(saved_variables['state_dict']) model.to(device) return model