mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-18 21:30:19 +00:00
Fix device
This commit is contained in:
parent
da325a0ba7
commit
8aac10f3fa
1 changed files with 1 additions and 1 deletions
|
|
@ -40,7 +40,7 @@ class Actor(BaseNetwork):
|
|||
if self.full_std:
|
||||
return self.log_std
|
||||
# Reduce the number of parameters:
|
||||
return th.ones((self.latent_dim, self.action_dim)) * self.log_std
|
||||
return th.ones((self.latent_dim, self.action_dim)).to(self.log_std.device) * self.log_std
|
||||
|
||||
def get_distribution_stats(self, obs, action):
|
||||
with th.no_grad():
|
||||
|
|
|
|||
Loading…
Reference in a new issue