diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 4d191fc..8e02b74 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -132,8 +132,8 @@ class BaseBuffer(ABC): :return: """ if copy: - return th.tensor(array).to(self.device, non_blocking=True) - return th.as_tensor(array).to(self.device, non_blocking=True) + return th.tensor(array).to(self.device) + return th.as_tensor(array).to(self.device) @staticmethod def _normalize_obs( diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index b7c4cb9..1a3f871 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -461,9 +461,9 @@ def obs_as_tensor( :return: PyTorch tensor of the observation on a desired device. """ if isinstance(obs, np.ndarray): - return th.as_tensor(obs).to(device, non_blocking=True) + return th.as_tensor(obs).to(device) elif isinstance(obs, dict): - return {key: th.as_tensor(_obs).to(device, non_blocking=True) for (key, _obs) in obs.items()} + return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} else: raise Exception(f"Unrecognized type of observation {type(obs)}")