diff --git a/torchy_baselines/td3/policies.py b/torchy_baselines/td3/policies.py index 89dc019..96fae45 100644 --- a/torchy_baselines/td3/policies.py +++ b/torchy_baselines/td3/policies.py @@ -40,7 +40,7 @@ class Actor(BaseNetwork): if self.full_std: return self.log_std # Reduce the number of parameters: - return th.ones((self.latent_dim, self.action_dim)) * self.log_std + return th.ones((self.latent_dim, self.action_dim)).to(self.log_std.device) * self.log_std def get_distribution_stats(self, obs, action): with th.no_grad():