Fix device

This commit is contained in:
Antonin Raffin 2019-11-13 14:38:18 +01:00
parent da325a0ba7
commit 8aac10f3fa

View file

@ -40,7 +40,7 @@ class Actor(BaseNetwork):
if self.full_std:
return self.log_std
# Reduce the number of parameters:
return th.ones((self.latent_dim, self.action_dim)) * self.log_std
return th.ones((self.latent_dim, self.action_dim)).to(self.log_std.device) * self.log_std
def get_distribution_stats(self, obs, action):
with th.no_grad():