mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Correctly pass device arg during recursion
This commit is contained in:
parent
0837c7e442
commit
9cd45c8b48
2 changed files with 2 additions and 2 deletions
|
|
@ -67,7 +67,7 @@ class BatchMixFeature(BatchFeature):
|
|||
def _recursive_to(obj, device, *args, **kwargs):
|
||||
# Lists can be nested, so keep digging until we hit tensors
|
||||
if isinstance(obj, list):
|
||||
return [_recursive_to(o, *args, **kwargs) for o in obj]
|
||||
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||
# cast and send to device
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class BatchMixFeature(BatchFeature):
|
|||
def _recursive_to(obj, device, *args, **kwargs):
|
||||
# Lists can be nested, so keep digging until we hit tensors
|
||||
if isinstance(obj, list):
|
||||
return [_recursive_to(o, *args, **kwargs) for o in obj]
|
||||
return [_recursive_to(o, device, *args, **kwargs) for o in obj]
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
|
||||
# cast and send to device
|
||||
|
|
|
|||
Loading…
Reference in a new issue