From 6b147d8cca54e525d960d6d076bc7bc30cb1c73f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 17 Dec 2022 12:10:42 +0100 Subject: [PATCH] Use non-blocking move --- stable_baselines3/common/buffers.py | 4 ++-- stable_baselines3/common/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 1896924..09b80c3 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -132,8 +132,8 @@ class BaseBuffer(ABC): :return: """ if copy: - return th.tensor(array).to(self.device) - return th.as_tensor(array).to(self.device) + return th.tensor(array).to(self.device, non_blocking=True) + return th.as_tensor(array).to(self.device, non_blocking=True) @staticmethod def _normalize_obs( diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 1a3f871..b7c4cb9 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -461,9 +461,9 @@ def obs_as_tensor( :return: PyTorch tensor of the observation on a desired device. """ if isinstance(obs, np.ndarray): - return th.as_tensor(obs).to(device) + return th.as_tensor(obs).to(device, non_blocking=True) elif isinstance(obs, dict): - return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + return {key: th.as_tensor(_obs).to(device, non_blocking=True) for (key, _obs) in obs.items()} else: raise Exception(f"Unrecognized type of observation {type(obs)}")