mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-07-01 03:45:11 +00:00
parent
15834461be
commit
1fec326402
2 changed files with 4 additions and 4 deletions
|
|
@ -132,8 +132,8 @@ class BaseBuffer(ABC):
|
|||
:return:
|
||||
"""
|
||||
if copy:
|
||||
return th.tensor(array).to(self.device, non_blocking=True)
|
||||
return th.as_tensor(array).to(self.device, non_blocking=True)
|
||||
return th.tensor(array).to(self.device)
|
||||
return th.as_tensor(array).to(self.device)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_obs(
|
||||
|
|
|
|||
|
|
@ -461,9 +461,9 @@ def obs_as_tensor(
|
|||
:return: PyTorch tensor of the observation on a desired device.
|
||||
"""
|
||||
if isinstance(obs, np.ndarray):
|
||||
return th.as_tensor(obs).to(device, non_blocking=True)
|
||||
return th.as_tensor(obs).to(device)
|
||||
elif isinstance(obs, dict):
|
||||
return {key: th.as_tensor(_obs).to(device, non_blocking=True) for (key, _obs) in obs.items()}
|
||||
return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
|
||||
else:
|
||||
raise Exception(f"Unrecognized type of observation {type(obs)}")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue