Revert "Use non-blocking move"

This reverts commit 6b147d8cca.
This commit is contained in:
Antonin Raffin 2022-12-17 12:48:37 +01:00
parent 15834461be
commit 1fec326402
No known key found for this signature in database
GPG key ID: B8B48F65CAD6232C
2 changed files with 4 additions and 4 deletions

View file

@ -132,8 +132,8 @@ class BaseBuffer(ABC):
:return:
"""
if copy:
return th.tensor(array).to(self.device, non_blocking=True)
return th.as_tensor(array).to(self.device, non_blocking=True)
return th.tensor(array).to(self.device)
return th.as_tensor(array).to(self.device)
@staticmethod
def _normalize_obs(

View file

@ -461,9 +461,9 @@ def obs_as_tensor(
:return: PyTorch tensor of the observation on a desired device.
"""
if isinstance(obs, np.ndarray):
return th.as_tensor(obs).to(device, non_blocking=True)
return th.as_tensor(obs).to(device)
elif isinstance(obs, dict):
return {key: th.as_tensor(_obs).to(device, non_blocking=True) for (key, _obs) in obs.items()}
return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
else:
raise Exception(f"Unrecognized type of observation {type(obs)}")