Use non-blocking move

This commit is contained in:
Antonin Raffin 2022-12-17 12:10:42 +01:00
parent 6d55a09f81
commit 6b147d8cca
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
2 changed files with 4 additions and 4 deletions

View file

@ -132,8 +132,8 @@ class BaseBuffer(ABC):
:return:
"""
if copy:
return th.tensor(array).to(self.device)
return th.as_tensor(array).to(self.device)
return th.tensor(array).to(self.device, non_blocking=True)
return th.as_tensor(array).to(self.device, non_blocking=True)
@staticmethod
def _normalize_obs(

View file

@ -461,9 +461,9 @@ def obs_as_tensor(
:return: PyTorch tensor of the observation on a desired device.
"""
if isinstance(obs, np.ndarray):
return th.as_tensor(obs).to(device)
return th.as_tensor(obs).to(device, non_blocking=True)
elif isinstance(obs, dict):
return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
return {key: th.as_tensor(_obs).to(device, non_blocking=True) for (key, _obs) in obs.items()}
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")