From 8aac10f3fa24cf3051a0a67a39cbdd55270feb57 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 13 Nov 2019 14:38:18 +0100 Subject: [PATCH] Fix device --- torchy_baselines/td3/policies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 89dc019..96fae45 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -40,7 +40,7 @@ class Actor(BaseNetwork): if self.full_std: return self.log_std # Reduce the number of parameters: - return th.ones((self.latent_dim, self.action_dim)) * self.log_std + return th.ones((self.latent_dim, self.action_dim)).to(self.log_std.device) * self.log_std def get_distribution_stats(self, obs, action): with th.no_grad():