uniformize kwargs for SAM (#34578)

* Make kwargs uniform for SAM

* Remove unused attribute

* Make point_pad_value part of image_kwargs

* Update annotations

* Code review - use existing methods

* Use ProcessorTesterMixin

* Do not add ProcessorTesterMixin everywhere
This commit is contained in:
Tibor Reiss 2024-12-23 13:54:57 +01:00 committed by GitHub
parent 2bb60982ac
commit e10be82b71
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 81 additions and 29 deletions

View file

@ -17,13 +17,14 @@ Processor class for SAM.
"""
from copy import deepcopy
from typing import Optional, Union
from typing import List, Optional, Union
import numpy as np
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...utils import TensorType, is_tf_available, is_torch_available
from ...image_utils import ImageInput, VideoInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available
if is_torch_available():
@ -33,6 +34,23 @@ if is_tf_available():
import tensorflow as tf
class SamImagesKwargs(ImagesKwargs):
segmentation_maps: Optional[ImageInput]
input_points: Optional[List[List[float]]]
input_labels: Optional[List[List[int]]]
input_boxes: Optional[List[List[List[float]]]]
point_pad_value: Optional[int]
class SamProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: SamImagesKwargs
_defaults = {
"images_kwargs": {
"point_pad_value": -10,
}
}
class SamProcessor(ProcessorMixin):
r"""
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):
attributes = ["image_processor"]
image_processor_class = "SamImageProcessor"
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = [
"segmentation_maps",
"input_points",
"input_labels",
"input_boxes",
]
def __init__(self, image_processor):
super().__init__(image_processor)
self.current_processor = self.image_processor
self.point_pad_value = -10
self.target_size = self.image_processor.size["longest_edge"]
def __call__(
self,
images=None,
segmentation_maps=None,
input_points=None,
input_labels=None,
input_boxes=None,
return_tensors: Optional[Union[str, TensorType]] = None,
images: Optional[ImageInput] = None,
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
# arguments that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
# or this conversation for more context:
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args, # to be deprecated
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio: Optional[AudioInput] = None,
video: Optional[VideoInput] = None,
**kwargs,
) -> BatchEncoding:
"""
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
points and bounding boxes for the model if they are provided.
"""
output_kwargs = self._merge_kwargs(
SamProcessorKwargs,
tokenizer_init_kwargs={},
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
encoding_image_processor = self.image_processor(
images,
segmentation_maps=segmentation_maps,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["images_kwargs"],
)
# pop arguments that are not used in the foward but used nevertheless
@ -94,7 +130,8 @@ class SamProcessor(ProcessorMixin):
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors=return_tensors,
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
)
return encoding_image_processor
@ -107,6 +144,7 @@ class SamProcessor(ProcessorMixin):
input_labels=None,
input_boxes=None,
return_tensors="pt",
point_pad_value=-10,
):
if input_points is not None:
if len(original_sizes) != len(input_points):
@ -121,7 +159,9 @@ class SamProcessor(ProcessorMixin):
# check that all arrays have the same shape
if not all(point.shape == input_points[0].shape for point in input_points):
if input_labels is not None:
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
input_points, input_labels = self._pad_points_and_labels(
input_points, input_labels, point_pad_value
)
input_points = np.array(input_points)
@ -174,7 +214,7 @@ class SamProcessor(ProcessorMixin):
return encoding_image_processor
def _pad_points_and_labels(self, input_points, input_labels):
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
r"""
The method pads the 2D points and labels to the maximum number of points in the batch.
"""
@ -183,9 +223,9 @@ class SamProcessor(ProcessorMixin):
for i, point in enumerate(input_points):
if point.shape[0] != expected_nb_points:
point = np.concatenate(
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
)
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
input_labels[i] = np.append(input_labels[i], [point_pad_value])
processed_input_points.append(point)
input_points = processed_input_points
return input_points, input_labels

View file

@ -26,7 +26,7 @@ from transformers.testing_utils import (
)
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
from ...test_processing_common import prepare_image_inputs
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs
if is_vision_available():
@ -43,7 +43,9 @@ if is_tf_available():
@require_vision
@require_torchvision
class SamProcessorTest(unittest.TestCase):
class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = SamProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SamImageProcessor()
@ -56,11 +58,6 @@ class SamProcessorTest(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
def prepare_mask_inputs(self):
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
or a list of PyTorch tensors if one specifies torchify=True.
@ -69,6 +66,21 @@ class SamProcessorTest(unittest.TestCase):
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
return mask_inputs
def test_chat_template_save_loading(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_image_processor_defaults_preserved_by_image_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_kwargs_overrides_default_image_processor_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_kwargs_overrides_default_tokenizer_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_tokenizer_defaults_preserved_by_kwargs(self):
self.skipTest("SamProcessor does not have a tokenizer")
def test_save_load_pretrained_additional_features(self):
processor = SamProcessor(image_processor=self.get_image_processor())
processor.save_pretrained(self.tmpdirname)
@ -165,7 +177,7 @@ class TFSamProcessorTest(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()
@ -248,7 +260,7 @@ class SamProcessorEquivalenceTest(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
# This is to avoid repeating the skipping of the common tests
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
return prepare_image_inputs()