From e9a49babeecbcb23db97debd88c42da351822878 Mon Sep 17 00:00:00 2001 From: Amrit Sahu <88420255+sahamrit@users.noreply.github.com> Date: Fri, 7 Oct 2022 19:30:19 +0530 Subject: [PATCH] [WIP] Add ZeroShotObjectDetectionPipeline (#18445) (#18930) * Add ZeroShotObjectDetectionPipeline (#18445) * Add AutoModelForZeroShotObjectDetection task This commit also adds the following - Add explicit _processor method for ZeroShotObjectDetectionPipeline. This is necessary as pipelines don't auto infer processors yet and `OwlVitProcessor` wraps tokenizer and feature_extractor together, to process multiple images at once - Add auto tests and other tests for ZeroShotObjectDetectionPipeline * Add AutoModelForZeroShotObjectDetection task This commit also adds the following - Add explicit _processor method for ZeroShotObjectDetectionPipeline. This is necessary as pipelines don't auto infer processors yet and `OwlVitProcessor` wraps tokenizer and feature_extractor together, to process multiple images at once - Add auto tests and other tests for ZeroShotObjectDetectionPipeline * Add batching for ZeroShotObjectDetectionPipeline * Fix doc-string ZeroShotObjectDetectionPipeline * Fix output format: ZeroShotObjectDetectionPipeline --- docs/source/en/main_classes/pipelines.mdx | 7 + docs/source/en/model_doc/auto.mdx | 4 + src/transformers/__init__.py | 6 + src/transformers/models/auto/__init__.py | 4 + src/transformers/models/auto/modeling_auto.py | 19 ++ src/transformers/pipelines/__init__.py | 9 + .../pipelines/zero_shot_object_detection.py | 278 ++++++++++++++++++ src/transformers/utils/dummy_pt_objects.py | 10 + ...st_pipelines_zero_shot_object_detection.py | 263 +++++++++++++++++ utils/update_metadata.py | 5 + 10 files changed, 605 insertions(+) create mode 100644 src/transformers/pipelines/zero_shot_object_detection.py create mode 100644 tests/pipelines/test_pipelines_zero_shot_object_detection.py diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx index 4043a0000..5374f1a40 100644 --- a/docs/source/en/main_classes/pipelines.mdx +++ b/docs/source/en/main_classes/pipelines.mdx @@ -43,6 +43,7 @@ There are two categories of pipeline abstractions to be aware about: - [`VisualQuestionAnsweringPipeline`] - [`ZeroShotClassificationPipeline`] - [`ZeroShotImageClassificationPipeline`] + - [`ZeroShotObjectDetectionPipeline`] ## The pipeline abstraction @@ -456,6 +457,12 @@ See [`TokenClassificationPipeline`] for all details. - __call__ - all +### ZeroShotObjectDetectionPipeline + +[[autodoc]] ZeroShotObjectDetectionPipeline + - __call__ + - all + ## Parent class: `Pipeline` [[autodoc]] Pipeline diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index 93976424b..01db8c4b1 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -174,6 +174,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] AutoModelForInstanceSegmentation +## AutoModelForZeroShotObjectDetection + +[[autodoc]] AutoModelForZeroShotObjectDetection + ## TFAutoModel [[autodoc]] TFAutoModel diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 18bfea30a..026ec59eb 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -442,6 +442,7 @@ _import_structure = { "VisualQuestionAnsweringPipeline", "ZeroShotClassificationPipeline", "ZeroShotImageClassificationPipeline", + "ZeroShotObjectDetectionPipeline", "pipeline", ], "processing_utils": ["ProcessorMixin"], @@ -878,6 +879,7 @@ else: "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "AutoModel", "AutoModelForAudioClassification", "AutoModelForAudioFrameClassification", @@ -905,6 +907,7 @@ else: "AutoModelForVision2Seq", "AutoModelForVisualQuestionAnswering", "AutoModelWithLMHead", + "AutoModelForZeroShotObjectDetection", ] ) _import_structure["models.bart"].extend( @@ -3407,6 +3410,7 @@ if TYPE_CHECKING: VisualQuestionAnsweringPipeline, ZeroShotClassificationPipeline, ZeroShotImageClassificationPipeline, + ZeroShotObjectDetectionPipeline, pipeline, ) from .processing_utils import ProcessorMixin @@ -3772,6 +3776,7 @@ if TYPE_CHECKING: MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, AutoModel, @@ -3800,6 +3805,7 @@ if TYPE_CHECKING: AutoModelForVideoClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) from .models.bart import ( diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 6129253f1..1964c7393 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -69,6 +69,7 @@ else: "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", "AutoModel", "AutoModelForAudioClassification", "AutoModelForAudioFrameClassification", @@ -96,6 +97,7 @@ else: "AutoModelForVisualQuestionAnswering", "AutoModelForDocumentQuestionAnswering", "AutoModelWithLMHead", + "AutoModelForZeroShotObjectDetection", ] try: @@ -215,6 +217,7 @@ if TYPE_CHECKING: MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, MODEL_FOR_VISION_2_SEQ_MAPPING, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, MODEL_MAPPING, MODEL_WITH_LM_HEAD_MAPPING, AutoModel, @@ -243,6 +246,7 @@ if TYPE_CHECKING: AutoModelForVideoClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotObjectDetection, AutoModelWithLMHead, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 4cf9b58a5..237c98c5b 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -472,6 +472,13 @@ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = OrderedDict( + [ + # Model for Zero Shot Object Detection mapping + ("owlvit", "OwlViTForObjectDetection") + ] +) + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict( [ # Model for Seq2Seq Causal LM mapping @@ -830,6 +837,9 @@ MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES ) MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES) +MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES +) MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES ) @@ -1016,6 +1026,15 @@ class AutoModelForObjectDetection(_BaseAutoModelClass): AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") +class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING + + +AutoModelForZeroShotObjectDetection = auto_class_update( + AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection" +) + + class AutoModelForVideoClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index da0f3d4d8..0a8787281 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -72,6 +72,7 @@ from .token_classification import ( from .visual_question_answering import VisualQuestionAnsweringPipeline from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline from .zero_shot_image_classification import ZeroShotImageClassificationPipeline +from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline if is_tf_available(): @@ -124,6 +125,7 @@ if is_torch_available(): AutoModelForTokenClassification, AutoModelForVision2Seq, AutoModelForVisualQuestionAnswering, + AutoModelForZeroShotObjectDetection, ) if TYPE_CHECKING: from ..modeling_tf_utils import TFPreTrainedModel @@ -335,6 +337,13 @@ SUPPORTED_TASKS = { "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}}, "type": "image", }, + "zero-shot-object-detection": { + "impl": ZeroShotObjectDetectionPipeline, + "tf": (), + "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (), + "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}}, + "type": "multimodal", + }, } NO_FEATURE_EXTRACTOR_TASKS = set() diff --git a/src/transformers/pipelines/zero_shot_object_detection.py b/src/transformers/pipelines/zero_shot_object_detection.py new file mode 100644 index 000000000..8c18bd502 --- /dev/null +++ b/src/transformers/pipelines/zero_shot_object_detection.py @@ -0,0 +1,278 @@ +from typing import Dict, List, Union + +import numpy as np + +from ..tokenization_utils_base import BatchEncoding +from ..utils import ( + add_end_docstrings, + is_tf_available, + is_torch_available, + is_vision_available, + logging, + requires_backends, +) +from .base import PIPELINE_INIT_ARGS, Pipeline + + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_image + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING + +logger = logging.get_logger(__name__) + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class ZeroShotObjectDetectionPipeline(Pipeline): + """ + Zero shot object detection pipeline using `OwlViTForObjectDetection`. This pipeline predicts bounding boxes of + objects when you provide an image and a set of `candidate_labels`. + + This object detection pipeline can currently be loaded from [`pipeline`] using the following task identifier: + `"zero-shot-object-detection"`. + + See the list of available models on + [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-object-detection). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.framework == "tf": + raise ValueError(f"The {self.__class__} is only available in PyTorch.") + + requires_backends(self, "vision") + self.check_model_type(MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING) + + def __call__( + self, + images: Union[str, List[str], "Image.Image", List["Image.Image"]], + text_queries: Union[str, List[str], List[List[str]]] = None, + **kwargs + ): + """ + Detect objects (bounding boxes & classes) in the image(s) passed as inputs. + + Args: + images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`): + The pipeline handles three types of images: + + - A string containing an http url pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + text_queries (`str` or `List[str]` or `List[List[str]]`): Text queries to query the target image with. + If given multiple images, `text_queries` should be provided as a list of lists, where each nested list + contains the text queries for the corresponding image. + + threshold (`float`, *optional*, defaults to 0.1): + The probability necessary to make a prediction. + + top_k (`int`, *optional*, defaults to None): + The number of top predictions that will be returned by the pipeline. If the provided number is `None` + or higher than the number of predictions available, it will default to the number of predictions. + + + Return: + A list of lists containing prediction results, one list per input image. Each list contains dictionaries + with the following keys: + + - **label** (`str`) -- Text query corresponding to the found object. + - **score** (`float`) -- Score corresponding to the object (between 0 and 1). + - **box** (`Dict[str,int]`) -- Bounding box of the detected object in image's original size. It is a + dictionary with `x_min`, `x_max`, `y_min`, `y_max` keys. + """ + if isinstance(text_queries, str) or (isinstance(text_queries, List) and not isinstance(text_queries[0], List)): + if isinstance(images, (str, Image.Image)): + inputs = {"images": images, "text_queries": text_queries} + elif isinstance(images, List): + assert len(images) == 1, "Input text_queries and images must have correspondance" + inputs = {"images": images[0], "text_queries": text_queries} + else: + raise TypeError(f"Innapropriate type of images: {type(images)}") + + elif isinstance(text_queries, str) or (isinstance(text_queries, List) and isinstance(text_queries[0], List)): + if isinstance(images, (Image.Image, str)): + images = [images] + assert len(images) == len(text_queries), "Input text_queries and images must have correspondance" + inputs = {"images": images, "text_queries": text_queries} + else: + """ + Supports the following format + - {"images": images, "text_queries": text_queries} + """ + inputs = images + results = super().__call__(inputs, **kwargs) + return results + + def _sanitize_parameters(self, **kwargs): + postprocess_params = {} + if "threshold" in kwargs: + postprocess_params["threshold"] = kwargs["threshold"] + if "top_k" in kwargs: + postprocess_params["top_k"] = kwargs["top_k"] + return {}, {}, postprocess_params + + def preprocess(self, inputs): + if not isinstance(inputs["images"], List): + inputs["images"] = [inputs["images"]] + images = [load_image(img) for img in inputs["images"]] + text_queries = inputs["text_queries"] + if isinstance(text_queries, str) or isinstance(text_queries[0], str): + text_queries = [text_queries] + + target_sizes = [torch.IntTensor([[img.height, img.width]]) for img in images] + target_sizes = torch.cat(target_sizes) + inputs = self._processor(text=inputs["text_queries"], images=images, return_tensors="pt") + return {"target_sizes": target_sizes, "text_queries": text_queries, **inputs} + + def _forward(self, model_inputs): + target_sizes = model_inputs.pop("target_sizes") + text_queries = model_inputs.pop("text_queries") + outputs = self.model(**model_inputs) + + model_outputs = outputs.__class__({"target_sizes": target_sizes, "text_queries": text_queries, **outputs}) + return model_outputs + + def postprocess(self, model_outputs, threshold=0.1, top_k=None): + texts = model_outputs["text_queries"] + + outputs = self.feature_extractor.post_process( + outputs=model_outputs, target_sizes=model_outputs["target_sizes"] + ) + + results = [] + for i in range(len(outputs)): + keep = outputs[i]["scores"] >= threshold + labels = outputs[i]["labels"][keep].tolist() + scores = outputs[i]["scores"][keep].tolist() + boxes = [self._get_bounding_box(box) for box in outputs[i]["boxes"][keep]] + + result = [ + {"score": score, "label": texts[i][label], "box": box} + for score, label, box in zip(scores, labels, boxes) + ] + + result = sorted(result, key=lambda x: x["score"], reverse=True) + if top_k: + result = result[:top_k] + results.append(result) + + return results + + def _get_bounding_box(self, box: "torch.Tensor") -> Dict[str, int]: + """ + Turns list [xmin, xmax, ymin, ymax] into dict { "xmin": xmin, ... } + + Args: + box (`torch.Tensor`): Tensor containing the coordinates in corners format. + + Returns: + bbox (`Dict[str, int]`): Dict containing the coordinates in corners format. + """ + if self.framework != "pt": + raise ValueError("The ZeroShotObjectDetectionPipeline is only available in PyTorch.") + xmin, ymin, xmax, ymax = box.int().tolist() + bbox = { + "xmin": xmin, + "ymin": ymin, + "xmax": xmax, + "ymax": ymax, + } + return bbox + + # Replication of OwlViTProcessor __call__ method, since pipelines don't auto infer processor's yet! + def _processor(self, text=None, images=None, padding="max_length", return_tensors="np", **kwargs): + """ + Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and + `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode: + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPFeatureExtractor's [`~CLIPFeatureExtractor.__call__`] if `images` is not `None`. Please refer to the + doctsring of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, + `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a + number of channels, H and W are image height and width. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if text is None and images is None: + raise ValueError("You have to specify at least one text or image. Both cannot be none.") + + if text is not None: + if isinstance(text, str) or (isinstance(text, List) and not isinstance(text[0], List)): + encodings = [self.tokenizer(text, padding=padding, return_tensors=return_tensors, **kwargs)] + + elif isinstance(text, List) and isinstance(text[0], List): + encodings = [] + + # Maximum number of queries across batch + max_num_queries = max([len(t) for t in text]) + + # Pad all batch samples to max number of text queries + for t in text: + if len(t) != max_num_queries: + t = t + [" "] * (max_num_queries - len(t)) + + encoding = self.tokenizer(t, padding=padding, return_tensors=return_tensors, **kwargs) + encodings.append(encoding) + else: + raise TypeError("Input text should be a string, a list of strings or a nested list of strings") + + if return_tensors == "np": + input_ids = np.concatenate([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = np.concatenate([encoding["attention_mask"] for encoding in encodings], axis=0) + + elif return_tensors == "pt" and is_torch_available(): + import torch + + input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0) + attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0) + + elif return_tensors == "tf" and is_tf_available(): + import tensorflow as tf + + input_ids = tf.stack([encoding["input_ids"] for encoding in encodings], axis=0) + attention_mask = tf.stack([encoding["attention_mask"] for encoding in encodings], axis=0) + + else: + raise ValueError("Target return tensor type could not be returned") + + encoding = BatchEncoding() + encoding["input_ids"] = input_ids + encoding["attention_mask"] = attention_mask + + if images is not None: + image_features = self.feature_extractor(images, return_tensors=return_tensors, **kwargs) + + if text is not None and images is not None: + encoding["pixel_values"] = image_features.pixel_values + return encoding + elif text is not None: + return encoding + else: + return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ef1a6baaf..72db36cab 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -418,6 +418,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = None MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = None +MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING = None + + MODEL_MAPPING = None @@ -606,6 +609,13 @@ class AutoModelForVisualQuestionAnswering(metaclass=DummyObject): requires_backends(self, ["torch"]) +class AutoModelForZeroShotObjectDetection(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelWithLMHead(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/test_pipelines_zero_shot_object_detection.py b/tests/pipelines/test_pipelines_zero_shot_object_detection.py new file mode 100644 index 000000000..10b7e799c --- /dev/null +++ b/tests/pipelines/test_pipelines_zero_shot_object_detection.py @@ -0,0 +1,263 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, is_vision_available, pipeline +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_tf, + require_torch, + require_vision, + slow, +) + +from .test_pipelines_common import ANY, PipelineTestCaseMeta + + +if is_vision_available(): + from PIL import Image +else: + + class Image: + @staticmethod + def open(*args, **kwargs): + pass + + +@require_vision +@require_torch +@is_pipeline_test +class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): + + model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING + + def get_test_pipeline(self, model, tokenizer, feature_extractor): + object_detector = pipeline( + "zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection" + ) + + examples = [ + { + "images": "./tests/fixtures/tests_samples/COCO/000000039769.png", + "text_queries": ["cat", "remote", "couch"], + } + ] + return object_detector, examples + + def run_pipeline_test(self, object_detector, examples): + batch_outputs = object_detector(examples, threshold=0.0) + + self.assertEqual(len(examples), len(batch_outputs)) + for outputs in batch_outputs: + for output_per_image in outputs: + self.assertGreater(len(output_per_image), 0) + for detected_object in output_per_image: + self.assertEqual( + detected_object, + { + "score": ANY(float), + "label": ANY(str), + "box": {"xmin": ANY(int), "ymin": ANY(int), "xmax": ANY(int), "ymax": ANY(int)}, + }, + ) + + @require_tf + @unittest.skip("Zero Shot Object Detection not implemented in TF") + def test_small_model_tf(self): + pass + + @require_torch + def test_small_model_pt(self): + object_detector = pipeline( + "zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection" + ) + + outputs = object_detector( + "./tests/fixtures/tests_samples/COCO/000000039769.png", + text_queries=["cat", "remote", "couch"], + threshold=0.64, + ) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}}, + {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}}, + {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}}, + {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}}, + ] + ], + ) + + outputs = object_detector( + ["./tests/fixtures/tests_samples/COCO/000000039769.png"], + text_queries=["cat", "remote", "couch"], + threshold=0.64, + ) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}}, + {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}}, + {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}}, + {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}}, + ] + ], + ) + + outputs = object_detector( + "./tests/fixtures/tests_samples/COCO/000000039769.png", + text_queries=[["cat", "remote", "couch"]], + threshold=0.64, + ) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}}, + {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}}, + {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}}, + {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}}, + ] + ], + ) + + outputs = object_detector( + [ + "./tests/fixtures/tests_samples/COCO/000000039769.png", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ], + text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]], + threshold=0.64, + ) + + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}}, + {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}}, + {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}}, + {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}}, + ], + [ + {"score": 0.7235, "label": "cat", "box": {"xmin": 204, "ymin": 167, "xmax": 232, "ymax": 190}}, + {"score": 0.6748, "label": "remote", "box": {"xmin": 571, "ymin": 83, "xmax": 598, "ymax": 103}}, + {"score": 0.6456, "label": "remote", "box": {"xmin": 494, "ymin": 105, "xmax": 521, "ymax": 127}}, + {"score": 0.642, "label": "remote", "box": {"xmin": 67, "ymin": 274, "xmax": 93, "ymax": 297}}, + ], + ], + ) + + @require_torch + @slow + def test_large_model_pt(self): + object_detector = pipeline("zero-shot-object-detection") + + outputs = object_detector( + "http://images.cocodataset.org/val2017/000000039769.jpg", text_queries=["cat", "remote", "couch"] + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, + {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}}, + {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}}, + ] + ], + ) + + outputs = object_detector( + [ + "http://images.cocodataset.org/val2017/000000039769.jpg", + "http://images.cocodataset.org/val2017/000000039769.jpg", + ], + text_queries=[["cat", "remote", "couch"], ["cat", "remote", "couch"]], + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, + {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}}, + {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}}, + ], + [ + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, + {"score": 0.1474, "label": "remote", "box": {"xmin": 335, "ymin": 74, "xmax": 371, "ymax": 187}}, + {"score": 0.1208, "label": "couch", "box": {"xmin": 4, "ymin": 0, "xmax": 642, "ymax": 476}}, + ], + ], + ) + + @require_tf + @unittest.skip("Zero Shot Object Detection not implemented in TF") + def test_large_model_tf(self): + pass + + @require_torch + @slow + def test_threshold(self): + threshold = 0.2 + object_detector = pipeline("zero-shot-object-detection") + + outputs = object_detector( + "http://images.cocodataset.org/val2017/000000039769.jpg", + text_queries=["cat", "remote", "couch"], + threshold=threshold, + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + {"score": 0.2537, "label": "cat", "box": {"xmin": 1, "ymin": 55, "xmax": 315, "ymax": 472}}, + ] + ], + ) + + @require_torch + @slow + def test_top_k(self): + top_k = 2 + object_detector = pipeline("zero-shot-object-detection") + + outputs = object_detector( + "http://images.cocodataset.org/val2017/000000039769.jpg", + text_queries=["cat", "remote", "couch"], + top_k=top_k, + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.2868, "label": "cat", "box": {"xmin": 324, "ymin": 20, "xmax": 640, "ymax": 373}}, + {"score": 0.277, "label": "remote", "box": {"xmin": 40, "ymin": 72, "xmax": 177, "ymax": 115}}, + ] + ], + ) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index aaf296c04..8bb3b7167 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -58,6 +58,11 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ ("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"), ("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"), ("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"), + ( + "zero-shot-object-detection", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES", + "AutoModelForZeroShotObjectDetection", + ), ("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"), ("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"), ("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),