Fix: don't change the device when saving

This commit is contained in:
Antonin RAFFIN 2020-03-31 18:38:37 +02:00
parent 71ce9ef2f4
commit 6e470d0f72

View file

@ -197,11 +197,7 @@ class BasePolicy(nn.Module):
:param path: (str)
"""
previous_device = self.device
# Convert to cpu before saving
self = self.to('cpu')
th.save(self.state_dict(), path)
self = self.to(previous_device)
def load(self, path: str) -> None:
"""
@ -211,7 +207,6 @@ class BasePolicy(nn.Module):
:param path: (str)
"""
self.load_state_dict(th.load(path))
self = self.to(self.device)
def load_from_vector(self, vector: np.ndarray):
"""