diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index 41b7211..9aff338 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -200,6 +200,12 @@ class BasePolicy(nn.Module): .format(observation_space)) def _get_data(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the policy. + This corresponds to the arguments of the constructor. + + :return: (Dict[str, Any]) + """ return dict( observation_space=self.observation_space, action_space=self.action_space, @@ -218,15 +224,19 @@ class BasePolicy(nn.Module): th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) @classmethod - def load(cls, path: str) -> 'BasePolicy': + def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy': """ Load policy from path. :param path: (str) + :param device: (Union[th.device, str]) Device on which the policy should be loaded. + :return: (BasePolicy) """ - device = get_device() + device = get_device(device) saved_variables = th.load(path, map_location=device) + # Create policy object model = cls(**saved_variables['data']) + # Load weights model.load_state_dict(saved_variables['state_dict']) model.to(device) return model