Skip to content

Commit

Permalink
Fix test device for buffers (#1993)
Browse files Browse the repository at this point in the history
* Prevent test_device from being a noop

* Update changelog

---------

Co-authored-by: Adrià Garriga-Alonso <adria@far.ai>
  • Loading branch information
araffin and rhaps0dy authored Aug 18, 2024
1 parent 4a1137b commit 4a7631b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Bug Fixes:
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)
- Updated DQN optimizer input to only include q_network parameters, removing the target_q_network ones (@corentinlger)
- Fixed ``test_buffers.py::test_device`` which was not actually checking the device of tensors (@rhaps0dy)


`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
21 changes: 14 additions & 7 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,25 @@ def test_device_buffer(replay_buffer_cls, device):

# Get data from the buffer
if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
# get returns an iterator over minibatches
data = buffer.get(50)
elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
data = buffer.sample(50)
data = [buffer.sample(50)]

# Check that all data are on the desired device
desired_device = get_device(device).type
for value in list(data):
if isinstance(value, dict):
for key in value.keys():
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device
for minibatch in list(data):
for value in minibatch:
if isinstance(value, dict):
for key in value.keys():
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device
elif isinstance(value, np.ndarray):
# For prioritized replay weights/indices
pass
else:
raise TypeError(f"Unknown value type: {type(value)}")


def test_custom_rollout_buffer():
Expand Down

0 comments on commit 4a7631b

Please sign in to comment.