mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
uniformize kwargs for SAM (#34578)
* Make kwargs uniform for SAM * Remove unused attribute * Make point_pad_value part of image_kwargs * Update annotations * Code review - use existing methods * Use ProcessorTesterMixin * Do not add ProcessorTesterMixin everywhere
This commit is contained in:
parent
2bb60982ac
commit
e10be82b71
2 changed files with 81 additions and 29 deletions
|
|
@ -17,13 +17,14 @@ Processor class for SAM.
|
|||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ...processing_utils import ProcessorMixin
|
||||
from ...tokenization_utils_base import BatchEncoding
|
||||
from ...utils import TensorType, is_tf_available, is_torch_available
|
||||
from ...image_utils import ImageInput, VideoInput
|
||||
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
|
||||
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
|
||||
from ...utils import is_tf_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
|
@ -33,6 +34,23 @@ if is_tf_available():
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
class SamImagesKwargs(ImagesKwargs):
|
||||
segmentation_maps: Optional[ImageInput]
|
||||
input_points: Optional[List[List[float]]]
|
||||
input_labels: Optional[List[List[int]]]
|
||||
input_boxes: Optional[List[List[List[float]]]]
|
||||
point_pad_value: Optional[int]
|
||||
|
||||
|
||||
class SamProcessorKwargs(ProcessingKwargs, total=False):
|
||||
images_kwargs: SamImagesKwargs
|
||||
_defaults = {
|
||||
"images_kwargs": {
|
||||
"point_pad_value": -10,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class SamProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a SAM processor which wraps a SAM image processor and an 2D points & Bounding boxes processor into a
|
||||
|
|
@ -48,32 +66,50 @@ class SamProcessor(ProcessorMixin):
|
|||
|
||||
attributes = ["image_processor"]
|
||||
image_processor_class = "SamImageProcessor"
|
||||
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
|
||||
optional_call_args = [
|
||||
"segmentation_maps",
|
||||
"input_points",
|
||||
"input_labels",
|
||||
"input_boxes",
|
||||
]
|
||||
|
||||
def __init__(self, image_processor):
|
||||
super().__init__(image_processor)
|
||||
self.current_processor = self.image_processor
|
||||
self.point_pad_value = -10
|
||||
self.target_size = self.image_processor.size["longest_edge"]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images=None,
|
||||
segmentation_maps=None,
|
||||
input_points=None,
|
||||
input_labels=None,
|
||||
input_boxes=None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
images: Optional[ImageInput] = None,
|
||||
# The following is to capture `segmentation_maps`, `input_points`, `input_labels` and `input_boxes`
|
||||
# arguments that may be passed as a positional argument.
|
||||
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details,
|
||||
# or this conversation for more context:
|
||||
# https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116
|
||||
# This behavior is only needed for backward compatibility and will be removed in future versions.
|
||||
*args, # to be deprecated
|
||||
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
|
||||
audio: Optional[AudioInput] = None,
|
||||
video: Optional[VideoInput] = None,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
"""
|
||||
This method uses [`SamImageProcessor.__call__`] method to prepare image(s) for the model. It also prepares 2D
|
||||
points and bounding boxes for the model if they are provided.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
SamProcessorKwargs,
|
||||
tokenizer_init_kwargs={},
|
||||
**kwargs,
|
||||
**self.prepare_and_validate_optional_call_args(*args),
|
||||
)
|
||||
input_points = output_kwargs["images_kwargs"].pop("input_points", None)
|
||||
input_labels = output_kwargs["images_kwargs"].pop("input_labels", None)
|
||||
input_boxes = output_kwargs["images_kwargs"].pop("input_boxes", None)
|
||||
|
||||
encoding_image_processor = self.image_processor(
|
||||
images,
|
||||
segmentation_maps=segmentation_maps,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
**output_kwargs["images_kwargs"],
|
||||
)
|
||||
|
||||
# pop arguments that are not used in the foward but used nevertheless
|
||||
|
|
@ -94,7 +130,8 @@ class SamProcessor(ProcessorMixin):
|
|||
input_points=input_points,
|
||||
input_labels=input_labels,
|
||||
input_boxes=input_boxes,
|
||||
return_tensors=return_tensors,
|
||||
return_tensors=output_kwargs["common_kwargs"].get("return_tensors"),
|
||||
point_pad_value=output_kwargs["images_kwargs"].get("point_pad_value"),
|
||||
)
|
||||
|
||||
return encoding_image_processor
|
||||
|
|
@ -107,6 +144,7 @@ class SamProcessor(ProcessorMixin):
|
|||
input_labels=None,
|
||||
input_boxes=None,
|
||||
return_tensors="pt",
|
||||
point_pad_value=-10,
|
||||
):
|
||||
if input_points is not None:
|
||||
if len(original_sizes) != len(input_points):
|
||||
|
|
@ -121,7 +159,9 @@ class SamProcessor(ProcessorMixin):
|
|||
# check that all arrays have the same shape
|
||||
if not all(point.shape == input_points[0].shape for point in input_points):
|
||||
if input_labels is not None:
|
||||
input_points, input_labels = self._pad_points_and_labels(input_points, input_labels)
|
||||
input_points, input_labels = self._pad_points_and_labels(
|
||||
input_points, input_labels, point_pad_value
|
||||
)
|
||||
|
||||
input_points = np.array(input_points)
|
||||
|
||||
|
|
@ -174,7 +214,7 @@ class SamProcessor(ProcessorMixin):
|
|||
|
||||
return encoding_image_processor
|
||||
|
||||
def _pad_points_and_labels(self, input_points, input_labels):
|
||||
def _pad_points_and_labels(self, input_points, input_labels, point_pad_value):
|
||||
r"""
|
||||
The method pads the 2D points and labels to the maximum number of points in the batch.
|
||||
"""
|
||||
|
|
@ -183,9 +223,9 @@ class SamProcessor(ProcessorMixin):
|
|||
for i, point in enumerate(input_points):
|
||||
if point.shape[0] != expected_nb_points:
|
||||
point = np.concatenate(
|
||||
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + self.point_pad_value], axis=0
|
||||
[point, np.zeros((expected_nb_points - point.shape[0], 2)) + point_pad_value], axis=0
|
||||
)
|
||||
input_labels[i] = np.append(input_labels[i], [self.point_pad_value])
|
||||
input_labels[i] = np.append(input_labels[i], [point_pad_value])
|
||||
processed_input_points.append(point)
|
||||
input_points = processed_input_points
|
||||
return input_points, input_labels
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from transformers.testing_utils import (
|
|||
)
|
||||
from transformers.utils import is_tf_available, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_processing_common import prepare_image_inputs
|
||||
from ...test_processing_common import ProcessorTesterMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
|
|
@ -43,7 +43,9 @@ if is_tf_available():
|
|||
|
||||
@require_vision
|
||||
@require_torchvision
|
||||
class SamProcessorTest(unittest.TestCase):
|
||||
class SamProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = SamProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
image_processor = SamImageProcessor()
|
||||
|
|
@ -56,11 +58,6 @@ class SamProcessorTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images."""
|
||||
return prepare_image_inputs()
|
||||
|
||||
def prepare_mask_inputs(self):
|
||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||
or a list of PyTorch tensors if one specifies torchify=True.
|
||||
|
|
@ -69,6 +66,21 @@ class SamProcessorTest(unittest.TestCase):
|
|||
mask_inputs = [Image.fromarray(x) for x in mask_inputs]
|
||||
return mask_inputs
|
||||
|
||||
def test_chat_template_save_loading(self):
|
||||
self.skipTest("SamProcessor does not have a tokenizer")
|
||||
|
||||
def test_image_processor_defaults_preserved_by_image_kwargs(self):
|
||||
self.skipTest("SamProcessor does not have a tokenizer")
|
||||
|
||||
def test_kwargs_overrides_default_image_processor_kwargs(self):
|
||||
self.skipTest("SamProcessor does not have a tokenizer")
|
||||
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs(self):
|
||||
self.skipTest("SamProcessor does not have a tokenizer")
|
||||
|
||||
def test_tokenizer_defaults_preserved_by_kwargs(self):
|
||||
self.skipTest("SamProcessor does not have a tokenizer")
|
||||
|
||||
def test_save_load_pretrained_additional_features(self):
|
||||
processor = SamProcessor(image_processor=self.get_image_processor())
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
|
@ -165,7 +177,7 @@ class TFSamProcessorTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
# Processor tester class can't use ProcessorTesterMixin as processor is atypical e.g. only contains an image processor and it assumes torch
|
||||
# This is to avoid repeating the skipping of the common tests
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images."""
|
||||
return prepare_image_inputs()
|
||||
|
|
@ -248,7 +260,7 @@ class SamProcessorEquivalenceTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
# Processor tester class can't use ProcessorTesterMixin atm because the processor is atypical e.g. only contains an image processor
|
||||
# This is to avoid repeating the skipping of the common tests
|
||||
def prepare_image_inputs(self):
|
||||
"""This function prepares a list of PIL images."""
|
||||
return prepare_image_inputs()
|
||||
|
|
|
|||
Loading…
Reference in a new issue