diff --git a/torchy_baselines/common/policies.py b/torchy_baselines/common/policies.py index d6677d6..58661f1 100644 --- a/torchy_baselines/common/policies.py +++ b/torchy_baselines/common/policies.py @@ -197,11 +197,7 @@ class BasePolicy(nn.Module): :param path: (str) """ - previous_device = self.device - # Convert to cpu before saving - self = self.to('cpu') th.save(self.state_dict(), path) - self = self.to(previous_device) def load(self, path: str) -> None: """ @@ -211,7 +207,6 @@ class BasePolicy(nn.Module): :param path: (str) """ self.load_state_dict(th.load(path)) - self = self.to(self.device) def load_from_vector(self, vector: np.ndarray): """