mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[Vision] .to function for ImageProcessors (#20536)
* add v1 with tests * add checker * simplified version * update docstring * better version * fix docstring + change order * make style * tests + change conditions * final tests * modify docstring * Update src/transformers/feature_extraction_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * replace by `ValueError` * fix logic * apply suggestions * `dtype` is not needed * adapt suggestions * remove `_parse_args_to_device` Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
67d32f4649
commit
ef0f85cd57
5 changed files with 102 additions and 14 deletions
|
|
@ -40,6 +40,7 @@ from .utils import (
|
|||
is_tf_available,
|
||||
is_torch_available,
|
||||
is_torch_device,
|
||||
is_torch_dtype,
|
||||
logging,
|
||||
torch_required,
|
||||
)
|
||||
|
|
@ -47,7 +48,7 @@ from .utils import (
|
|||
|
||||
if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
import torch
|
||||
import torch # noqa
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
@ -138,7 +139,7 @@ class BatchFeature(UserDict):
|
|||
elif tensor_type == TensorType.PYTORCH:
|
||||
if not is_torch_available():
|
||||
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
|
||||
import torch
|
||||
import torch # noqa
|
||||
|
||||
def as_tensor(value):
|
||||
if isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], np.ndarray):
|
||||
|
|
@ -175,25 +176,47 @@ class BatchFeature(UserDict):
|
|||
return self
|
||||
|
||||
@torch_required
|
||||
# Copied from transformers.tokenization_utils_base.BatchEncoding.to with BatchEncoding->BatchFeature
|
||||
def to(self, device: Union[str, "torch.device"]) -> "BatchFeature":
|
||||
def to(self, *args, **kwargs) -> "BatchFeature":
|
||||
"""
|
||||
Send all values to device by calling `v.to(device)` (PyTorch only).
|
||||
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||
different `dtypes` and sending the `BatchFeature` to a different `device`.
|
||||
|
||||
Args:
|
||||
device (`str` or `torch.device`): The device to put the tensors on.
|
||||
args (`Tuple`):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
kwargs (`Dict`, *optional*):
|
||||
Will be passed to the `to(...)` function of the tensors.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
import torch # noqa
|
||||
|
||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(f"Attempting to cast a BatchFeature to type {str(device)}. This is not supported.")
|
||||
new_data = {}
|
||||
device = kwargs.get("device")
|
||||
# Check if the args are a device or a dtype
|
||||
if device is None and len(args) > 0:
|
||||
# device should be always the first argument
|
||||
arg = args[0]
|
||||
if is_torch_dtype(arg):
|
||||
# The first argument is a dtype
|
||||
pass
|
||||
elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
|
||||
device = arg
|
||||
else:
|
||||
# it's something else
|
||||
raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
|
||||
# We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
|
||||
for k, v in self.items():
|
||||
# check if v is a floating point
|
||||
if torch.is_floating_point(v):
|
||||
# cast and send to device
|
||||
new_data[k] = v.to(*args, **kwargs)
|
||||
elif device is not None:
|
||||
new_data[k] = v.to(device=device)
|
||||
else:
|
||||
new_data[k] = v
|
||||
self.data = new_data
|
||||
return self
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from .generic import (
|
|||
is_tensor,
|
||||
is_tf_tensor,
|
||||
is_torch_device,
|
||||
is_torch_dtype,
|
||||
is_torch_tensor,
|
||||
reshape,
|
||||
squeeze,
|
||||
|
|
|
|||
|
|
@ -123,6 +123,24 @@ def is_torch_device(x):
|
|||
return False if not is_torch_available() else _is_torch_device(x)
|
||||
|
||||
|
||||
def _is_torch_dtype(x):
|
||||
import torch
|
||||
|
||||
if isinstance(x, str):
|
||||
if hasattr(torch, x):
|
||||
x = getattr(torch, x)
|
||||
else:
|
||||
return False
|
||||
return isinstance(x, torch.dtype)
|
||||
|
||||
|
||||
def is_torch_dtype(x):
|
||||
"""
|
||||
Tests if `x` is a torch dtype or not. Safe to call even if torch is not installed.
|
||||
"""
|
||||
return False if not is_torch_available() else _is_torch_dtype(x)
|
||||
|
||||
|
||||
def _is_tensorflow(x):
|
||||
import tensorflow as tf
|
||||
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ class DeiTFeatureExtractionTester(unittest.TestCase):
|
|||
class DeiTFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
|
||||
|
||||
feature_extraction_class = DeiTFeatureExtractor if is_vision_available() else None
|
||||
test_cast_dtype = True
|
||||
|
||||
def setUp(self):
|
||||
self.feature_extract_tester = DeiTFeatureExtractionTester(self)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,15 @@ from pathlib import Path
|
|||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
check_json_file_has_correct_format,
|
||||
get_tests_dir,
|
||||
is_staging_test,
|
||||
require_torch,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
|
|
@ -134,6 +142,8 @@ def prepare_video_inputs(feature_extract_tester, equal_resolution=False, numpify
|
|||
|
||||
|
||||
class FeatureExtractionSavingTestMixin:
|
||||
test_cast_dtype = None
|
||||
|
||||
def test_feat_extract_to_json_string(self):
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
obj = json.loads(feat_extract.to_json_string())
|
||||
|
|
@ -164,6 +174,41 @@ class FeatureExtractionSavingTestMixin:
|
|||
feat_extract = self.feature_extraction_class()
|
||||
self.assertIsNotNone(feat_extract)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_cast_dtype_device(self):
|
||||
if self.test_cast_dtype is not None:
|
||||
# Initialize feature_extractor
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
# create random PyTorch tensors
|
||||
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||
# for layoutLM compatiblity
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float32)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt").to(torch.float16)
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16)
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.bfloat16)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_ = feature_extractor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu")
|
||||
|
||||
# Try with text + image feature
|
||||
encoding = feature_extractor(image_inputs, return_tensors="pt")
|
||||
encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])})
|
||||
encoding = encoding.to(torch.float16)
|
||||
|
||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
||||
|
||||
|
||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue