mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-14 20:58:03 +00:00
Fix test device for buffers (#1993)
* Prevent test_device from being a noop
* Update changelog
---------
Co-authored-by: Adrià Garriga-Alonso <adria@far.ai>
This commit is contained in:
parent
4a1137ba3a
commit
4a7631b71d
2 changed files with 16 additions and 7 deletions
|
|
@@ -30,6 +30,8 @@ Bug Fixes:
|
|||
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
|
||||
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)
|
||||
- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
|
||||
- Fixed ``test_buffers.py::test_device`` which was not actually checking the device of tensors (@rhaps0dy)
|
||||
|
||||
|
||||
`SB3-Contrib`_
|
||||
^^^^^^^^^^^^^^
|
||||
|
|
|
|||
|
|
@@ -139,18 +139,25 @@ def test_device_buffer(replay_buffer_cls, device):
|
|||
|
||||
# Get data from the buffer
|
||||
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
|
||||
# get returns an iterator over minibatches
|
||||
data = buffer.get(50)
|
||||
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
|
||||
data = buffer.sample(50)
|
||||
data = [buffer.sample(50)]
|
||||
|
||||
# Check that all data are on the desired device
|
||||
desired_device = get_device(device).type
|
||||
for value in list(data):
|
||||
if isinstance(value, dict):
|
||||
for key in value.keys():
|
||||
assert value[key].device.type == desired_device
|
||||
elif isinstance(value, th.Tensor):
|
||||
assert value.device.type == desired_device
|
||||
for minibatch in list(data):
|
||||
for value in minibatch:
|
||||
if isinstance(value, dict):
|
||||
for key in value.keys():
|
||||
assert value[key].device.type == desired_device
|
||||
elif isinstance(value, th.Tensor):
|
||||
assert value.device.type == desired_device
|
||||
elif isinstance(value, np.ndarray):
|
||||
# For prioritized replay weights/indices
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"Unknown value type: {type(value)}")
|
||||
|
||||
|
||||
def test_custom_rollout_buffer():
|
||||
|
|
|
|||
Loading…
Reference in a new issue