From c2c5d0798ffe00fd8e6d48c54e1a1f5fb7873014 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 6 Apr 2020 15:17:30 +0200 Subject: [PATCH] Fix: pass device for SAC --- torchy_baselines/sac/policies.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 6ed115e..e83292f 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -245,7 +245,8 @@ class SACPolicy(BasePolicy): 'features_dim': self.features_dim, 'net_arch': self.net_arch, 'activation_fn': self.activation_fn, - 'normalize_images': normalize_images + 'normalize_images': normalize_images, + 'device': device } self.actor_kwargs = self.net_args.copy() sde_kwargs = {