Improve doc

This commit is contained in:
Antonin RAFFIN 2020-04-20 16:21:47 +02:00
parent f347474e6a
commit 6bd2d87f33

View file

@ -200,6 +200,12 @@ class BasePolicy(nn.Module):
.format(observation_space))
def _get_data(self) -> Dict[str, Any]:
"""
Get data that need to be saved in order to re-create the policy.
This corresponds to the arguments of the constructor.
:return: (Dict[str, Any])
"""
return dict(
observation_space=self.observation_space,
action_space=self.action_space,
@ -218,15 +224,19 @@ class BasePolicy(nn.Module):
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
@classmethod
def load(cls, path: str) -> 'BasePolicy':
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy':
"""
Load policy from path.
:param path: (str)
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
:return: (BasePolicy)
"""
device = get_device()
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Create policy object
model = cls(**saved_variables['data'])
# Load weights
model.load_state_dict(saved_variables['state_dict'])
model.to(device)
return model