mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-27 03:11:57 +00:00
Improve doc
This commit is contained in:
parent
f347474e6a
commit
6bd2d87f33
1 changed files with 12 additions and 2 deletions
|
|
@ -200,6 +200,12 @@ class BasePolicy(nn.Module):
|
|||
.format(observation_space))
|
||||
|
||||
def _get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get data that need to be saved in order to re-create the policy.
|
||||
This corresponds to the arguments of the constructor.
|
||||
|
||||
:return: (Dict[str, Any])
|
||||
"""
|
||||
return dict(
|
||||
observation_space=self.observation_space,
|
||||
action_space=self.action_space,
|
||||
|
|
@ -218,15 +224,19 @@ class BasePolicy(nn.Module):
|
|||
th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str) -> 'BasePolicy':
|
||||
def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BasePolicy':
|
||||
"""
|
||||
Load policy from path.
|
||||
|
||||
:param path: (str)
|
||||
:param device: ( Union[th.device, str]) Device on which the policy should be loaded.
|
||||
:return: (BasePolicy)
|
||||
"""
|
||||
device = get_device()
|
||||
device = get_device(device)
|
||||
saved_variables = th.load(path, map_location=device)
|
||||
# Create policy object
|
||||
model = cls(**saved_variables['data'])
|
||||
# Load weights
|
||||
model.load_state_dict(saved_variables['state_dict'])
|
||||
model.to(device)
|
||||
return model
|
||||
|
|
|
|||
Loading…
Reference in a new issue