Fix test device for buffers (#1993)

* Prevent test_device from being a noop

* Update changelog

---------

Co-authored-by: Adrià Garriga-Alonso <adria@far.ai>
This commit is contained in:
Antonin RAFFIN 2024-08-18 12:33:22 +02:00 committed by GitHub
parent 4a1137ba3a
commit 4a7631b71d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 7 deletions

View file

@ -30,6 +30,8 @@ Bug Fixes:
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)
- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
- Fixed ``test_buffers.py::test_device`` which was not actually checking the device of tensors (@rhaps0dy)
`SB3-Contrib`_
^^^^^^^^^^^^^^

View file

@ -139,18 +139,25 @@ def test_device_buffer(replay_buffer_cls, device):
# Get data from the buffer
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
# get returns an iterator over minibatches
data = buffer.get(50)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
data = buffer.sample(50)
data = [buffer.sample(50)]
# Check that all data are on the desired device
desired_device = get_device(device).type
for value in list(data):
if isinstance(value, dict):
for key in value.keys():
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device
for minibatch in list(data):
for value in minibatch:
if isinstance(value, dict):
for key in value.keys():
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device
elif isinstance(value, np.ndarray):
# For prioritized replay weights/indices
pass
else:
raise TypeError(f"Unknown value type: {type(value)}")
def test_custom_rollout_buffer():