mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
add uniform processors for altclip + chinese_clip (#31198)
* add initial design for uniform processors + align model
* add uniform processors for altclip + chinese_clip
* fix mutable default 👀
* add configuration test
* handle structured kwargs w defaults + add test
* protect torch-specific test
* fix style
* fix
* rebase
* update processor to generic kwargs + test
* fix style
* add sensible kwargs merge
* update test
* fix assertEqual
* move kwargs merging to processing common
* rework kwargs for type hinting
* just get Unpack from extensions
* run-slow[align]
* handle kwargs passed as nested dict
* add from_pretrained test for nested kwargs handling
* [run-slow]align
* update documentation + imports
* update audio inputs
* protect audio types, silly
* try removing imports
* make things simpler
* simplerer
* move out kwargs test to common mixin
* [run-slow]align
* skip tests for old processors
* [run-slow]align, clip
* !$#@!! protect imports, darn it
* [run-slow]align, clip
* [run-slow]align, clip
* update common processor testing
* add altclip
* add chinese_clip
* add pad_size
* [run-slow]align, clip, chinese_clip, altclip
* remove duplicated tests
* fix
* update doc
* improve documentation for default values
* add model_max_length testing
This parameter depends on tokenizers received.
* Raise if kwargs are specified in two places
* fix
* match defaults
* force padding
* fix tokenizer test
* clean defaults
* move tests to common
* remove try/catch block
* deprecate kwarg
* format
* add copyright + remove unused method
* [run-slow]altclip, chinese_clip
* clean imports
* fix version
* clean up deprecation
* fix style
* add corner case test on kwarg overlap
* resume processing - add Unpack as importable
* add tmpdirname
* fix altclip
* fix up
* add back crop_size to specific tests
* generalize tests to possible video_processor
* add back crop_size arg
* fixup overlapping kwargs test for qformer_tokenizer
* remove copied from
* fixup chinese_clip tests values
* fixup tests - qformer tokenizers
* [run-slow] altclip, chinese_clip
* remove prepare_image_inputs
This commit is contained in:
parent
4f0246e535
commit
413008c580
10 changed files with 463 additions and 52 deletions
|
|
@ -18,16 +18,11 @@ Image/Text processor class for ALIGN
|
|||
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
try:
|
||||
from typing import Unpack
|
||||
except ImportError:
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import (
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
Unpack,
|
||||
)
|
||||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
||||
|
||||
|
|
|
|||
|
|
@ -16,10 +16,16 @@
|
|||
Image/Text processor class for AltCLIP
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
|
||||
|
||||
class AltClipProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {}
|
||||
|
||||
|
||||
class AltCLIPProcessor(ProcessorMixin):
|
||||
|
|
@ -41,17 +47,8 @@ class AltCLIPProcessor(ProcessorMixin):
|
|||
image_processor_class = "CLIPImageProcessor"
|
||||
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
|
||||
|
||||
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
|
||||
feature_extractor = None
|
||||
if "feature_extractor" in kwargs:
|
||||
warnings.warn(
|
||||
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
|
||||
" instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
feature_extractor = kwargs.pop("feature_extractor")
|
||||
|
||||
image_processor = image_processor if image_processor is not None else feature_extractor
|
||||
@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
|
||||
def __init__(self, image_processor=None, tokenizer=None):
|
||||
if image_processor is None:
|
||||
raise ValueError("You need to specify an `image_processor`.")
|
||||
if tokenizer is None:
|
||||
|
|
@ -59,7 +56,14 @@ class AltCLIPProcessor(ProcessorMixin):
|
|||
|
||||
super().__init__(image_processor, tokenizer)
|
||||
|
||||
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[AltClipProcessorKwargs],
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
|
||||
|
|
@ -68,22 +72,20 @@ class AltCLIPProcessor(ProcessorMixin):
|
|||
of the above two methods for more information.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
|
||||
images (`ImageInput`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
||||
|
||||
|
|
@ -95,13 +97,24 @@ class AltCLIPProcessor(ProcessorMixin):
|
|||
"""
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
raise ValueError("You must specify either text or images.")
|
||||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You must specify either text or images.")
|
||||
output_kwargs = self._merge_kwargs(
|
||||
AltClipProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if text is not None:
|
||||
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
if images is not None:
|
||||
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
|
||||
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
# BC for explicit return_tensors
|
||||
if "return_tensors" in output_kwargs["common_kwargs"]:
|
||||
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
|
||||
|
||||
if text is not None and images is not None:
|
||||
encoding["pixel_values"] = image_features.pixel_values
|
||||
|
|
|
|||
|
|
@ -231,6 +231,7 @@ class ChineseCLIPImageProcessor(BaseImageProcessor):
|
|||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
|
|
|
|||
|
|
@ -17,9 +17,15 @@ Image/Text processor class for Chinese-CLIP
|
|||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...image_utils import ImageInput
|
||||
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
||||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class ChineseClipProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {}
|
||||
|
||||
|
||||
class ChineseCLIPProcessor(ProcessorMixin):
|
||||
|
|
@ -60,7 +66,14 @@ class ChineseCLIPProcessor(ProcessorMixin):
|
|||
super().__init__(image_processor, tokenizer)
|
||||
self.current_processor = self.image_processor
|
||||
|
||||
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
|
||||
images: ImageInput = None,
|
||||
audio=None,
|
||||
videos=None,
|
||||
**kwargs: Unpack[ChineseClipProcessorKwargs],
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
||||
and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
|
|
@ -79,12 +92,10 @@ class ChineseCLIPProcessor(ProcessorMixin):
|
|||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
Returns:
|
||||
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
|
||||
|
||||
|
|
@ -97,12 +108,20 @@ class ChineseCLIPProcessor(ProcessorMixin):
|
|||
|
||||
if text is None and images is None:
|
||||
raise ValueError("You have to specify either text or images. Both cannot be none.")
|
||||
output_kwargs = self._merge_kwargs(
|
||||
ChineseClipProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if text is not None:
|
||||
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
|
||||
|
||||
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
if images is not None:
|
||||
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
|
||||
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
|
||||
|
||||
# BC for explicit return_tensors
|
||||
if "return_tensors" in output_kwargs["common_kwargs"]:
|
||||
return_tensors = output_kwargs["common_kwargs"].pop("return_tensors", None)
|
||||
|
||||
if text is not None and images is not None:
|
||||
encoding["pixel_values"] = image_features.pixel_values
|
||||
|
|
|
|||
|
|
@ -20,11 +20,14 @@ import copy
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import typing_extensions
|
||||
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .image_utils import ChannelDimension, is_valid_image, is_vision_available
|
||||
|
|
@ -67,6 +70,11 @@ AUTO_TO_BASE_CLASS_MAPPING = {
|
|||
"AutoImageProcessor": "ImageProcessingMixin",
|
||||
}
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
Unpack = typing.Unpack
|
||||
else:
|
||||
Unpack = typing_extensions.Unpack
|
||||
|
||||
|
||||
class TextKwargs(TypedDict, total=False):
|
||||
"""
|
||||
|
|
@ -151,6 +159,8 @@ class ImagesKwargs(TypedDict, total=False):
|
|||
Standard deviation to use if normalizing the image.
|
||||
do_pad (`bool`, *optional*):
|
||||
Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
|
||||
pad_size (`Dict[str, int]`, *optional*):
|
||||
The size `{"height": int, "width" int}` to pad the images to.
|
||||
do_center_crop (`bool`, *optional*):
|
||||
Whether to center crop the image.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
|
|
@ -170,6 +180,7 @@ class ImagesKwargs(TypedDict, total=False):
|
|||
image_mean: Optional[Union[float, List[float]]]
|
||||
image_std: Optional[Union[float, List[float]]]
|
||||
do_pad: Optional[bool]
|
||||
pad_size: Optional[Dict[str, int]]
|
||||
do_center_crop: Optional[bool]
|
||||
data_format: Optional[ChannelDimension]
|
||||
input_data_format: Optional[Union[str, ChannelDimension]]
|
||||
|
|
@ -814,7 +825,8 @@ class ProcessorMixin(PushToHubMixin):
|
|||
# check if this key was passed as a flat kwarg.
|
||||
if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
|
||||
raise ValueError(
|
||||
f"Keyword argument {modality_key} was passed two times: in a dictionary for {modality} and as a **kwarg."
|
||||
f"Keyword argument {modality_key} was passed two times:\n"
|
||||
f"in a dictionary for {modality} and as a **kwarg."
|
||||
)
|
||||
elif modality_key in kwargs:
|
||||
kwarg_value = kwargs.pop(modality_key, "__empty__")
|
||||
|
|
|
|||
165
tests/models/altclip/test_processor_altclip.py
Normal file
165
tests/models/altclip/test_processor_altclip.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import AltCLIPProcessor, CLIPImageProcessor
|
||||
|
||||
|
||||
@require_vision
|
||||
class AltClipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = AltCLIPProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.model_id = "BAAI/AltCLIP"
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = CLIPImageProcessor()
|
||||
tokenizer = XLMRobertaTokenizer.from_pretrained(self.model_id)
|
||||
|
||||
processor = self.processor_class(image_processor, tokenizer)
|
||||
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return XLMRobertaTokenizer.from_pretrained(self.model_id, **kwargs)
|
||||
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
return XLMRobertaTokenizerFast.from_pretrained(self.model_id, **kwargs)
|
||||
|
||||
def get_image_processor(self, **kwargs):
|
||||
return CLIPImageProcessor.from_pretrained(self.model_id, **kwargs)
|
||||
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_unstructured_kwargs_batched(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer", "upper older longer string"]
|
||||
image_input = self.prepare_image_inputs() * 2
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
crop_size={"height": 214, "width": 214},
|
||||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 7)
|
||||
|
||||
def test_structured_kwargs_nested(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_structured_kwargs_nested_from_dict(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_unstructured_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
crop_size={"height": 214, "width": 214},
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_image_processor_defaults_preserved_by_image_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor", crop_size=(234, 234))
|
||||
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input)
|
||||
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
|
||||
|
|
@ -206,3 +206,129 @@ class ChineseCLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||
inputs = processor(text=input_str, images=image_input)
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_unstructured_kwargs_batched(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer", "upper older longer string"]
|
||||
image_input = self.prepare_image_inputs() * 2
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
crop_size={"height": 214, "width": 214},
|
||||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 6)
|
||||
|
||||
def test_structured_kwargs_nested(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_structured_kwargs_nested_from_dict(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, **all_kwargs)
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_unstructured_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
crop_size={"height": 214, "width": 214},
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_image_processor_defaults_preserved_by_image_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor", crop_size=(234, 234))
|
||||
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input)
|
||||
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
|
||||
|
||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
image_processor = self.get_component("image_processor", crop_size=(234, 234))
|
||||
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, crop_size=[224, 224])
|
||||
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
|
||||
|
|
|
|||
|
|
@ -409,3 +409,31 @@ class InstructBlipProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_overlapping_text_kwargs_handling(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
processor_kwargs = {}
|
||||
processor_kwargs["image_processor"] = self.get_component("image_processor")
|
||||
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "video_processor" in self.processor_class.attributes:
|
||||
processor_kwargs["video_processor"] = self.get_component("video_processor")
|
||||
|
||||
qformer_tokenizer = self.get_component("qformer_tokenizer")
|
||||
|
||||
processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
text_kwargs={"padding": "do_not_pad"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -423,3 +423,31 @@ class InstructBlipVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
def test_overlapping_text_kwargs_handling(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
processor_kwargs = {}
|
||||
processor_kwargs["image_processor"] = self.get_component("image_processor")
|
||||
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "video_processor" in self.processor_class.attributes:
|
||||
processor_kwargs["video_processor"] = self.get_component("video_processor")
|
||||
|
||||
qformer_tokenizer = self.get_component("qformer_tokenizer")
|
||||
|
||||
processor = self.processor_class(**processor_kwargs, qformer_tokenizer=qformer_tokenizer)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
text_kwargs={"padding": "do_not_pad"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -146,7 +146,6 @@ class ProcessorTesterMixin:
|
|||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(text=input_str, images=image_input, return_tensors="pt")
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 117)
|
||||
|
||||
|
|
@ -175,7 +174,6 @@ class ProcessorTesterMixin:
|
|||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
inputs = processor(
|
||||
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
|
||||
)
|
||||
|
|
@ -238,7 +236,6 @@ class ProcessorTesterMixin:
|
|||
padding="longest",
|
||||
max_length=76,
|
||||
)
|
||||
|
||||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 6)
|
||||
|
|
@ -311,3 +308,30 @@ class ProcessorTesterMixin:
|
|||
self.assertEqual(inputs["pixel_values"].shape[2], 214)
|
||||
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
|
||||
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
|
||||
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
|
||||
def test_overlapping_text_kwargs_handling(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
processor_kwargs = {}
|
||||
processor_kwargs["image_processor"] = self.get_component("image_processor")
|
||||
processor_kwargs["tokenizer"] = tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
if "video_processor" in self.processor_class.attributes:
|
||||
processor_kwargs["video_processor"] = self.get_component("video_processor")
|
||||
processor = self.processor_class(**processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
image_input = self.prepare_image_inputs()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(
|
||||
text=input_str,
|
||||
images=image_input,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
text_kwargs={"padding": "do_not_pad"},
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue