mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-24 22:25:13 +00:00
Fix: don't change the device when saving
This commit is contained in:
parent
71ce9ef2f4
commit
6e470d0f72
1 changed files with 0 additions and 5 deletions
|
|
@ -197,11 +197,7 @@ class BasePolicy(nn.Module):
|
|||
|
||||
:param path: (str)
|
||||
"""
|
||||
previous_device = self.device
|
||||
# Convert to cpu before saving
|
||||
self = self.to('cpu')
|
||||
th.save(self.state_dict(), path)
|
||||
self = self.to(previous_device)
|
||||
|
||||
def load(self, path: str) -> None:
|
||||
"""
|
||||
|
|
@ -211,7 +207,6 @@ class BasePolicy(nn.Module):
|
|||
:param path: (str)
|
||||
"""
|
||||
self.load_state_dict(th.load(path))
|
||||
self = self.to(self.device)
|
||||
|
||||
def load_from_vector(self, vector: np.ndarray):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue