diff --git a/torchy_baselines/sac/policies.py b/torchy_baselines/sac/policies.py index 6ed115e..e83292f 100644 --- a/torchy_baselines/sac/policies.py +++ b/torchy_baselines/sac/policies.py @@ -245,7 +245,8 @@ class SACPolicy(BasePolicy): 'features_dim': self.features_dim, 'net_arch': self.net_arch, 'activation_fn': self.activation_fn, - 'normalize_images': normalize_images + 'normalize_images': normalize_images, + 'device': device } self.actor_kwargs = self.net_args.copy() sde_kwargs = {