Correctly pass device arg during recursion

This commit is contained in:
Matt 2024-12-03 18:07:28 +00:00
parent 0837c7e442
commit 9cd45c8b48
2 changed files with 2 additions and 2 deletions

View file

@ -67,7 +67,7 @@ class BatchMixFeature(BatchFeature):
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, *args, **kwargs) for o in obj]
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device

View file

@ -70,7 +70,7 @@ class BatchMixFeature(BatchFeature):
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, *args, **kwargs) for o in obj]
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device