diff --git a/src/transformers/models/pixtral/image_processing_pixtral.py b/src/transformers/models/pixtral/image_processing_pixtral.py index 9f8a5ede3..69eb7c28e 100644 --- a/src/transformers/models/pixtral/image_processing_pixtral.py +++ b/src/transformers/models/pixtral/image_processing_pixtral.py @@ -67,7 +67,7 @@ class BatchMixFeature(BatchFeature): def _recursive_to(obj, device, *args, **kwargs): # Lists can be nested, so keep digging until we hit tensors if isinstance(obj, list): - return [_recursive_to(o, *args, **kwargs) for o in obj] + return [_recursive_to(o, device, *args, **kwargs) for o in obj] # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): # cast and send to device diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 4c5bfa844..218c290ae 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -70,7 +70,7 @@ class BatchMixFeature(BatchFeature): def _recursive_to(obj, device, *args, **kwargs): # Lists can be nested, so keep digging until we hit tensors if isinstance(obj, list): - return [_recursive_to(o, *args, **kwargs) for o in obj] + return [_recursive_to(o, device, *args, **kwargs) for o in obj] # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj): # cast and send to device