Mirror of https://github.com/saymrwulf/transformers.git (synced 2026-05-14 20:58:08 +00:00)
Fix case of nested tensors in BatchMixFeature

parent a5bb528471
commit 19876ea405
1 changed file with 17 additions and 17 deletions
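The change replaces the flattening list comprehension in `BatchMixFeature.to()` with a recursive helper, so that list-valued entries such as Pixtral-style `pixel_values` (a list of per-sample lists of image tensors) keep their nesting when they are moved or cast. Below is a minimal sketch of the difference; the shapes are invented and the helper names only mirror the diff, this is not the shipped implementation.

import torch

def flatten_to(v, device):
    # old behaviour: flatten two levels and keep only tensors -- nesting is lost
    return [el.to(device) for sample in v for el in sample if torch.is_tensor(el)]

def recursive_to(obj, device):
    # new behaviour: recurse into lists so per-sample nesting is preserved
    if isinstance(obj, list):
        return [recursive_to(o, device) for o in obj]
    if torch.is_tensor(obj) and torch.is_floating_point(obj):
        return obj.to(device)          # cast/move floating-point tensors
    if torch.is_tensor(obj):
        return obj.to(device=device)   # move everything else, no cast
    return obj

# pixel_values as a list of per-sample lists of image tensors (shapes invented)
pixel_values = [
    [torch.rand(3, 16, 16), torch.rand(3, 32, 32)],  # sample 0: two images
    [torch.rand(3, 16, 16)],                         # sample 1: one image
]

flat = flatten_to(pixel_values, "cpu")
nested = recursive_to(pixel_values, "cpu")
print(len(flat))                  # 3 -- per-sample grouping is gone
print([len(s) for s in nested])   # [2, 1] -- structure preserved

The old comprehension also silently dropped non-tensor elements and anything nested deeper than two levels; the recursive form returns objects it does not recognize unchanged instead.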
@@ -52,7 +52,7 @@ def is_image_or_image_url(elem):
 # Copied from transformers.models.pixtral.image_processing_pixtral.BatchMixFeature
 class BatchMixFeature(BatchFeature):
-    def to(self, *args, **kwargs) -> "BatchMixFeature":
+    def to(self, device, *args, **kwargs) -> "BatchMixFeature":
         """
         Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
         different `dtypes` and sending the `BatchFeature` to a different `device`.
@@ -66,10 +66,23 @@ class BatchMixFeature(BatchFeature):
         Returns:
             [`BatchFeature`]: The same instance after modification.
         """
+
+        def _recursive_to(obj, device, *args, **kwargs):
+            # Lists can be nested, so keep digging until we hit tensors
+            if isinstance(obj, list):
+                return [_recursive_to(o, device, *args, **kwargs) for o in obj]
+            # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
+            elif isinstance(obj, torch.Tensor) and torch.is_floating_point(obj):
+                # cast and send to device
+                return obj.to(*args, **kwargs)
+            elif isinstance(obj, torch.Tensor) and device is not None:
+                # only send to device, don't cast
+                return obj.to(device=device)
+            else:
+                return obj
+
         requires_backends(self, ["torch"])
         import torch  # noqa

-        new_data = {}
         device = kwargs.get("device")
         # Check if the args are a device or a dtype
         if device is None and len(args) > 0:
@@ -83,21 +96,8 @@ class BatchMixFeature(BatchFeature):
             else:
                 # it's something else
                 raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
-        # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
-        for k, v in self.items():
-            # check if v is a floating point
-            if isinstance(v, list):
-                new_data[k] = [
-                    element.to(*args, **kwargs) for sample in v for element in sample if is_torch_tensor(element)
-                ]
-            elif isinstance(v, torch.Tensor) and torch.is_floating_point(v):
-                # cast and send to device
-                new_data[k] = v.to(*args, **kwargs)
-            elif isinstance(v, torch.Tensor) and device is not None:
-                new_data[k] = v.to(device=device)
-            else:
-                new_data[k] = v
-        self.data = new_data
+
+        self.data = {k: _recursive_to(v, device, *args, **kwargs) for k, v in self.data.items()}
         return self
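Both the removed loop and the new helper keep the rule described by the `# We cast only floating point tensors ...` comment: dtype arguments apply only to floating-point tensors, while integer tensors such as `input_ids` are moved to the requested device but never cast. A standalone sketch of that rule, with illustrative names rather than the library code itself:

import torch

def cast_rule(t, device, dtype):
    # mirrors the guard from the diff: only floating-point tensors change dtype
    if torch.is_floating_point(t):
        return t.to(device=device, dtype=dtype)   # cast and move
    return t.to(device=device)                    # move only, keep integer dtype

pixel_values = torch.rand(1, 3, 8, 8)     # float32
input_ids = torch.tensor([[1, 2, 3]])     # int64

print(cast_rule(pixel_values, "cpu", torch.bfloat16).dtype)  # torch.bfloat16
print(cast_rule(input_ids, "cpu", torch.bfloat16).dtype)     # torch.int64, unchanged

Casting `input_ids` to a floating dtype would break embedding lookups, which is why the guard exists.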