diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py
index 5913e8688..e104c9858 100644
--- a/src/transformers/models/pixtral/processing_pixtral.py
+++ b/src/transformers/models/pixtral/processing_pixtral.py
@@ -66,10 +66,23 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa
 
-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -83,21 +96,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
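
For reviewers, a quick standalone illustration of the behavior change. The old loop flattened nested lists (`for sample in v for element in sample`) and silently dropped non-tensor entries, so Pixtral's list-of-lists of variable-sized image tensors lost its structure on `.to(...)`. The sketch below reproduces the patch's recursion without the transformers plumbing; the `batch` dict and its values are made up for the demo, not real processor output:

```python
import torch


# Same recursion as in the patch, extracted for illustration.
def _recursive_to(obj, device, *args, **kwargs):
    # Lists can be nested, so keep digging until we hit tensors
    if isinstance(obj, list):
        return [_recursive_to(o, device, *args, **kwargs) for o in obj]
    # Cast only floating point tensors so integer ids are never upcast
    elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
        return obj.to(*args, **kwargs)
    elif isinstance(obj, torch.Tensor) and device is not None:
        return obj.to(device=device)
    else:
        return obj


# Toy stand-in for a Pixtral batch: `pixel_values` is a list (batch) of lists
# (images per sample) of variable-sized tensors -- exactly the structure the
# old flattening loop destroyed.
batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),  # LongTensor, must keep its dtype
    "pixel_values": [[torch.rand(3, 64, 48), torch.rand(3, 32, 80)]],
}

moved = {k: _recursive_to(v, "cpu", dtype=torch.float16, device="cpu") for k, v in batch.items()}

assert moved["input_ids"].dtype == torch.int64             # untouched
assert moved["pixel_values"][0][0].dtype == torch.float16  # floats are cast
assert moved["pixel_values"][0][1].shape == (3, 32, 80)    # nesting preserved
```

Integer tensors such as `input_ids` keep their dtype and are only moved to the requested device, while floating-point tensors are both cast and moved, at any nesting depth.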