Fix handling of nested tensors in BatchMixFeature

This commit is contained in:
Matt 2024-12-03 17:49:47 +00:00
parent a5bb528471
commit 19876ea405

View file

@@ -52,7 +52,7 @@ def is_image_or_image_url(elem):
# Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
class BatchMixFeature(BatchFeature):
def to(self, *args, **kwargs) -> "BatchMixFeature":
def to(self, device, *args, **kwargs) -> "BatchMixFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
different `dtypes` and sending the `BatchFeature` to a different `device`.
@@ -66,10 +66,23 @@ class BatchMixFeature(BatchFeature):
Returns:
[`BatchFeature`]: The same instance after modification.
"""
def _recursive_to(obj, device, *args, **kwargs):
# Lists can be nested, so keep digging until we hit tensors
if isinstance(obj, list):
return [_recursive_to(o, *args, **kwargs) for o in obj]
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
# cast and send to device
return obj.to(*args, **kwargs)
elif isinstance(obj, torch.Tensor) and device is not None:
# only send to device, don't cast
return obj.to(device=device)
else:
return obj
requires_backends(self, ["torch"])
import torch # noqa
new_data = {}
device = kwargs.get("device")
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
@@ -83,21 +96,8 @@ class BatchMixFeature(BatchFeature):
else:
# it's something else
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
for k, v in self.items():
# check if v is a floating point
if isinstance(v, list):
new_data[k] = [
element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
]
elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
else:
new_data[k] = v
self.data = new_data
self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
return self