OmDet Turbo processor standardization (#34937)
* Fix docstring
* Fix docstring
* Add `classes_structure` to model output
* Update omdet postprocessing
* Adjust tests
* Update code example in docs
* Add deprecation to "classes" key in output
* Types, docs
* Fixing test
* Fix missed clip_boxes
* [run-slow] omdet_turbo
* Apply suggestions from code review
* Make CamelCase class

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
parent 94ae9a8da1 · commit 42b2857b01
5 changed files with 248 additions and 182 deletions
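At a glance, the user-facing change: `post_process_grounded_object_detection` now takes `text_labels` and `threshold` in place of the deprecated `classes` and `score_threshold`, and each per-image result gains integer `labels` alongside string `text_labels`. A minimal before/after sketch distilled from the diffs below (the `outputs`, `image`, and `processor` objects are assumed to come from the usual inference setup):

```python
# Old interface (still accepted for now, slated for removal in 4.51.0):
# results = processor.post_process_grounded_object_detection(
#     outputs, classes=["cat", "remote"], score_threshold=0.3,
#     target_sizes=[image.size[::-1]],
# )
# results[0]["classes"]  # list of class-name strings

# New interface:
results = processor.post_process_grounded_object_detection(
    outputs,
    text_labels=["cat", "remote"],
    threshold=0.3,
    target_sizes=[(image.height, image.width)],
)
result = results[0]
result["labels"]       # torch.Tensor of integer class ids
result["text_labels"]  # matching class-name strings
result["classes"]      # still readable, but now emits a FutureWarning
```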
docs/source/en/model_doc/omdet_turbo.md

@@ -44,37 +44,40 @@ One unique property of OmDet-Turbo compared to other zero-shot object detection
 Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:

 ```python
-import requests
-from PIL import Image
+>>> import torch
+>>> import requests
+>>> from PIL import Image

-from transformers import AutoProcessor, OmDetTurboForObjectDetection
+>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection

-processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
-model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

-url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(url, stream=True).raw)
-classes = ["cat", "remote"]
-inputs = processor(image, text=classes, return_tensors="pt")
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+>>> text_labels = ["cat", "remote"]
+>>> inputs = processor(image, text=text_labels, return_tensors="pt")

-outputs = model(**inputs)
+>>> with torch.no_grad():
+...     outputs = model(**inputs)

-# convert outputs (bounding boxes and class logits)
-results = processor.post_process_grounded_object_detection(
-    outputs,
-    classes=classes,
-    target_sizes=[image.size[::-1]],
-    score_threshold=0.3,
-    nms_threshold=0.3,
-)[0]
-for score, class_name, box in zip(
-    results["scores"], results["classes"], results["boxes"]
-):
-    box = [round(i, 1) for i in box.tolist()]
-    print(
-        f"Detected {class_name} with confidence "
-        f"{round(score.item(), 2)} at location {box}"
-    )
+>>> # convert outputs (bounding boxes and class logits)
+>>> results = processor.post_process_grounded_object_detection(
+...     outputs,
+...     target_sizes=[(image.height, image.width)],
+...     text_labels=text_labels,
+...     threshold=0.3,
+...     nms_threshold=0.3,
+... )
+>>> result = results[0]
+>>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
+>>> for box, score, text_label in zip(boxes, scores, text_labels):
+...     box = [round(i, 2) for i in box.tolist()]
+...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
+Detected remote with confidence 0.768 at location [39.89, 70.35, 176.74, 118.04]
+Detected cat with confidence 0.72 at location [11.6, 54.19, 314.8, 473.95]
+Detected remote with confidence 0.563 at location [333.38, 75.77, 370.7, 187.03]
+Detected cat with confidence 0.552 at location [345.15, 23.95, 639.75, 371.67]
 ```

 ### Multi image inference
@@ -93,22 +96,22 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen

 >>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
->>> classes1 = ["cat", "remote"]
->>> task1 = "Detect {}.".format(", ".join(classes1))
+>>> text_labels1 = ["cat", "remote"]
+>>> task1 = "Detect {}.".format(", ".join(text_labels1))

 >>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
 >>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
->>> classes2 = ["boat"]
+>>> text_labels2 = ["boat"]
 >>> task2 = "Detect everything that looks like a boat."

 >>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
 >>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
->>> classes3 = ["statue", "trees"]
+>>> text_labels3 = ["statue", "trees"]
 >>> task3 = "Focus on the foreground, detect statue and trees."

 >>> inputs = processor(
 ...     images=[image1, image2, image3],
-...     text=[classes1, classes2, classes3],
+...     text=[text_labels1, text_labels2, text_labels3],
 ...     task=[task1, task2, task3],
 ...     return_tensors="pt",
 ... )
@@ -119,19 +122,19 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen
 >>> # convert outputs (bounding boxes and class logits)
 >>> results = processor.post_process_grounded_object_detection(
 ...     outputs,
-...     classes=[classes1, classes2, classes3],
-...     target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
-...     score_threshold=0.2,
+...     text_labels=[text_labels1, text_labels2, text_labels3],
+...     target_sizes=[(image.height, image.width) for image in [image1, image2, image3]],
+...     threshold=0.2,
 ...     nms_threshold=0.3,
 ... )

 >>> for i, result in enumerate(results):
-...     for score, class_name, box in zip(
-...         result["scores"], result["classes"], result["boxes"]
+...     for score, text_label, box in zip(
+...         result["scores"], result["text_labels"], result["boxes"]
 ...     ):
 ...         box = [round(i, 1) for i in box.tolist()]
 ...         print(
-...             f"Detected {class_name} with confidence "
+...             f"Detected {text_label} with confidence "
 ...             f"{round(score.item(), 2)} at location {box} in image {i}"
 ...         )
 Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
src/transformers/models/omdet_turbo/modeling_omdet_turbo.py

@@ -143,22 +143,24 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
             The predicted class of the objects from the encoder.
         encoder_extracted_states (`torch.FloatTensor`):
             The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
-        decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
+        decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
             `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
             plus the initial embedding outputs.
-        decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
+        decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
             Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
             sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
             weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
-        encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
+        encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
             `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
             plus the initial embedding outputs.
-        encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
+        encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
             Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
             sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
             weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
+        classes_structure (`torch.LongTensor`, *optional*):
+            The number of queried classes for each image.
     """

     loss: torch.FloatTensor = None
@@ -173,6 +175,7 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    classes_structure: Optional[torch.LongTensor] = None


 # Copied from models.deformable_detr.load_cuda_kernels
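For orientation, `classes_structure` is simply the per-image count of queried classes; the processor-test hunk at the bottom of this diff builds it the same way. A tiny illustrative sketch:

```python
import torch

# Two images queried with different numbers of classes
classes_per_image = [["cat", "remote"], ["boat"]]

# One entry per image: how many classes were queried for it
classes_structure = torch.tensor([len(sublist) for sublist in classes_per_image])
print(classes_structure)  # tensor([2, 1])

# Post-processing later uses this to drop predictions whose label id
# points at a padded (non-existent) class slot: labels < classes_structure[i]
```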
@@ -1667,16 +1670,16 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
     @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values: Tensor,
-        classes_input_ids: Tensor,
-        classes_attention_mask: Tensor,
-        tasks_input_ids: Tensor,
-        tasks_attention_mask: Tensor,
-        classes_structure: Tensor,
-        labels: Optional[Tensor] = None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        pixel_values: torch.FloatTensor,
+        classes_input_ids: torch.LongTensor,
+        classes_attention_mask: torch.LongTensor,
+        tasks_input_ids: torch.LongTensor,
+        tasks_attention_mask: torch.LongTensor,
+        classes_structure: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
         r"""
         Returns:
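The tightened type hints match what the processor emits; a quick sanity-check sketch (assuming, as the processor's `_defaults` further down suggest, that a default task is filled in when none is passed):

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(image, text=["cat", "remote"], return_tensors="pt")
for key in ("classes_input_ids", "classes_attention_mask",
            "tasks_input_ids", "tasks_attention_mask", "classes_structure"):
    assert inputs[key].dtype == torch.int64  # LongTensor, matching the new hints
assert inputs["pixel_values"].dtype == torch.float32  # FloatTensor
```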
@@ -1770,6 +1773,7 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
                 decoder_outputs[2],
                 encoder_outputs[1],
                 encoder_outputs[2],
+                classes_structure,
             ]
             if output is not None
         )
@@ -1787,6 +1791,7 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
             decoder_attentions=decoder_outputs.attentions,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
+            classes_structure=classes_structure,
         )
src/transformers/models/omdet_turbo/processing_omdet_turbo.py

@@ -16,7 +16,8 @@
 Processor class for OmDet-Turbo.
 """

-from typing import List, Optional, Tuple, Union
+import warnings
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_transforms import center_to_corners_format
@@ -28,12 +29,25 @@ from ...utils import (
     is_torch_available,
     is_torchvision_available,
 )
+from ...utils.deprecation import deprecate_kwarg
+
+
+if TYPE_CHECKING:
+    from .modeling_omdet_turbo import OmDetTurboObjectDetectionOutput


 class OmDetTurboTextKwargs(TextKwargs, total=False):
     task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]]


+if is_torch_available():
+    import torch
+
+
+if is_torchvision_available():
+    from torchvision.ops.boxes import batched_nms
+
+
 class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
     text_kwargs: OmDetTurboTextKwargs
     _defaults = {
@@ -55,11 +69,23 @@ class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False):
     }


-if is_torch_available():
-    import torch
+class DictWithDeprecationWarning(dict):
+    message = (
+        "The `classes` key is deprecated for `OmDetTurboProcessor.post_process_grounded_object_detection` "
+        "output dict and will be removed in a 4.51.0 version. Please use `text_labels` instead."
+    )

-if is_torchvision_available():
-    from torchvision.ops.boxes import batched_nms
+    def __getitem__(self, key):
+        if key == "classes":
+            warnings.warn(self.message, FutureWarning)
+            return super().__getitem__("text_labels")
+        return super().__getitem__(key)
+
+    def get(self, key, *args, **kwargs):
+        if key == "classes":
+            warnings.warn(self.message, FutureWarning)
+            return super().get("text_labels", *args, **kwargs)
+        return super().get(key, *args, **kwargs)


 def clip_boxes(box, box_size: Tuple[int, int]):
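To see the backward-compatibility shim in action, a small usage sketch of `DictWithDeprecationWarning` (placeholder values):

```python
import warnings

result = DictWithDeprecationWarning(
    {"boxes": None, "scores": None, "labels": None, "text_labels": ["cat"]}
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = result["classes"]  # redirected to "text_labels", with a FutureWarning

print(legacy)                     # ['cat']
print(caught[0].category)         # <class 'FutureWarning'>
print(result.get("text_labels"))  # ['cat'], no warning on the new key
```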
@@ -97,76 +123,80 @@ def compute_score(boxes):


 def _post_process_boxes_for_image(
-    boxes: TensorType,
-    scores: TensorType,
-    predicted_classes: TensorType,
-    classes: List[str],
+    boxes: "torch.Tensor",
+    scores: "torch.Tensor",
+    labels: "torch.Tensor",
+    image_num_classes: int,
     image_size: Tuple[int, int],
-    num_classes: int,
-    score_threshold: float,
+    threshold: float,
     nms_threshold: float,
-    max_num_det: int = None,
-) -> dict:
+    max_num_det: Optional[int] = None,
+) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
     """
     Filter predicted results using given thresholds and NMS.

     Args:
-        boxes (torch.Tensor): A Tensor of predicted class-specific or class-agnostic
-            boxes for the image. Shape : (num_queries, max_num_classes_in_batch * 4) if doing
-            class-specific regression, or (num_queries, 4) if doing class-agnostic
-            regression.
-        scores (torch.Tensor): A Tensor of predicted class scores for the image.
-            Shape : (num_queries, max_num_classes_in_batch + 1)
-        predicted_classes (torch.Tensor): A Tensor of predicted classes for the image.
-            Shape : (num_queries * (max_num_classes_in_batch + 1),)
-        classes (List[str]): The input classes names.
-        image_size (tuple): A tuple of (height, width) for the image.
-        num_classes (int): The number of classes given for this image.
-        score_threshold (float): Only return detections with a confidence score exceeding this
-            threshold.
-        nms_threshold (float): The threshold to use for box non-maximum suppression. Value in [0, 1].
-        max_num_det (int, optional): The maximum number of detections to return. Default is None.
+        boxes (`torch.Tensor`):
+            A Tensor of predicted class-specific or class-agnostic boxes for the image.
+            Shape (num_queries, max_num_classes_in_batch * 4) if doing class-specific regression,
+            or (num_queries, 4) if doing class-agnostic regression.
+        scores (`torch.Tensor` of shape (num_queries, max_num_classes_in_batch + 1)):
+            A Tensor of predicted class scores for the image.
+        labels (`torch.Tensor` of shape (num_queries * (max_num_classes_in_batch + 1),)):
+            A Tensor of predicted labels for the image.
+        image_num_classes (`int`):
+            The number of classes queried for detection on the image.
+        image_size (`Tuple[int, int]`):
+            A tuple of (height, width) for the image.
+        threshold (`float`):
+            Only return detections with a confidence score exceeding this threshold.
+        nms_threshold (`float`):
+            The threshold to use for box non-maximum suppression. Value in [0, 1].
+        max_num_det (`int`, *optional*):
+            The maximum number of detections to return. Default is None.

     Returns:
-        dict: A dictionary the following keys:
+        Tuple: A tuple with the following:
             "boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format.
             "scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection.
-            "classes" (List[str]): A list of strings, where each string is the predicted class for the
-                corresponding detection
+            "labels" (Tensor): A tensor of ids, where each id is the predicted class id for the corresponding detection
     """

     # Filter by max number of detections
     proposal_num = len(boxes) if max_num_det is None else max_num_det
     scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False)
-    classes_per_image = predicted_classes[topk_indices]
-    box_pred_per_image = boxes.view(-1, 1, 4).repeat(1, num_classes, 1).view(-1, 4)
-    box_pred_per_image = box_pred_per_image[topk_indices]
+    labels_per_image = labels[topk_indices]
+    boxes_per_image = boxes.view(-1, 1, 4).repeat(1, scores.shape[1], 1).view(-1, 4)
+    boxes_per_image = boxes_per_image[topk_indices]

-    # Score filtering
-    box_pred_per_image = center_to_corners_format(box_pred_per_image)
-    box_pred_per_image = box_pred_per_image * torch.tensor(image_size[::-1]).repeat(2).to(box_pred_per_image.device)
-    filter_mask = scores_per_image > score_threshold  # R x K
+    # Convert and scale boxes to original image size
+    boxes_per_image = center_to_corners_format(boxes_per_image)
+    boxes_per_image = boxes_per_image * torch.tensor(image_size[::-1]).repeat(2).to(boxes_per_image.device)
+
+    # Filtering by confidence score
+    filter_mask = scores_per_image > threshold  # R x K
     score_keep = filter_mask.nonzero(as_tuple=False).view(-1)
-    box_pred_per_image = box_pred_per_image[score_keep]
+    boxes_per_image = boxes_per_image[score_keep]
     scores_per_image = scores_per_image[score_keep]
-    classes_per_image = classes_per_image[score_keep]
+    labels_per_image = labels_per_image[score_keep]

-    filter_classes_mask = classes_per_image < len(classes)
+    # Ensure we did not overflow to non existing classes
+    filter_classes_mask = labels_per_image < image_num_classes
     classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1)
-    box_pred_per_image = box_pred_per_image[classes_keep]
+    boxes_per_image = boxes_per_image[classes_keep]
     scores_per_image = scores_per_image[classes_keep]
-    classes_per_image = classes_per_image[classes_keep]
+    labels_per_image = labels_per_image[classes_keep]

     # NMS
-    keep = batched_nms(box_pred_per_image, scores_per_image, classes_per_image, nms_threshold)
-    box_pred_per_image = box_pred_per_image[keep]
+    keep = batched_nms(boxes_per_image, scores_per_image, labels_per_image, nms_threshold)
+    boxes_per_image = boxes_per_image[keep]
     scores_per_image = scores_per_image[keep]
-    classes_per_image = classes_per_image[keep]
-    classes_per_image = [classes[i] for i in classes_per_image]
+    labels_per_image = labels_per_image[keep]

-    # create an instance
-    result = {}
-    result["boxes"] = clip_boxes(box_pred_per_image, image_size)
-    result["scores"] = scores_per_image
-    result["classes"] = classes_per_image
+    # Clip to image size
+    boxes_per_image = clip_boxes(boxes_per_image, image_size)

-    return result
+    return boxes_per_image, scores_per_image, labels_per_image


 class OmDetTurboProcessor(ProcessorMixin):
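As a self-contained illustration of the same filtering pipeline on dummy tensors (shapes shrunk for readability; the label ids here are synthetic, whereas the real ones come from `compute_score`):

```python
import torch
from torchvision.ops.boxes import batched_nms
from transformers.image_transforms import center_to_corners_format

num_queries, num_class_slots = 4, 3
image_num_classes = 2     # only 2 of the 3 padded class slots are real
image_size = (480, 640)   # (height, width)

boxes = torch.rand(num_queries, 4)                 # normalized (cx, cy, w, h)
scores = torch.rand(num_queries, num_class_slots)  # per-query class scores

# Top-k over the flattened (query, class) grid
scores_flat, topk = scores.flatten(0, 1).topk(len(boxes), sorted=False)
labels = torch.arange(num_class_slots).repeat(num_queries)[topk]
boxes_flat = boxes.view(-1, 1, 4).repeat(1, num_class_slots, 1).view(-1, 4)[topk]

# Convert centers to corners and scale to pixel coordinates
boxes_flat = center_to_corners_format(boxes_flat)
boxes_flat = boxes_flat * torch.tensor(image_size[::-1]).repeat(2)

# Confidence filter, padded-class filter, then class-aware NMS
keep = scores_flat > 0.3
boxes_flat, scores_flat, labels = boxes_flat[keep], scores_flat[keep], labels[keep]
keep = labels < image_num_classes
boxes_flat, scores_flat, labels = boxes_flat[keep], scores_flat[keep], labels[keep]
keep = batched_nms(boxes_flat, scores_flat, labels, iou_threshold=0.5)
print(boxes_flat[keep].shape, scores_flat[keep].shape, labels[keep].shape)
```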
@@ -274,11 +304,26 @@ class OmDetTurboProcessor(ProcessorMixin):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    def _get_default_image_size(self) -> Tuple[int, int]:
+        height = (
+            self.image_processor.size["height"]
+            if "height" in self.image_processor.size
+            else self.image_processor.size["shortest_edge"]
+        )
+        width = (
+            self.image_processor.size["width"]
+            if "width" in self.image_processor.size
+            else self.image_processor.size["longest_edge"]
+        )
+        return height, width
+
+    @deprecate_kwarg("score_threshold", new_name="threshold", version="4.51.0")
+    @deprecate_kwarg("classes", new_name="text_labels", version="4.51.0")
     def post_process_grounded_object_detection(
         self,
-        outputs,
-        classes: Union[List[str], List[List[str]]],
-        score_threshold: float = 0.3,
+        outputs: "OmDetTurboObjectDetectionOutput",
+        text_labels: Optional[Union[List[str], List[List[str]]]] = None,
+        threshold: float = 0.3,
         nms_threshold: float = 0.5,
         target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
         max_num_det: Optional[int] = None,
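`deprecate_kwarg` is the library's own utility; purely to show the remapping behavior relied on here, a simplified stand-in (not the actual implementation):

```python
import functools
import warnings


def deprecate_kwarg_sketch(old_name, new_name=None, version=None):
    """Route a deprecated keyword argument to its new name, with a FutureWarning."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in version {version}, "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg_sketch("score_threshold", new_name="threshold", version="4.51.0")
def post_process(threshold=0.3):
    return threshold


assert post_process(score_threshold=0.5) == 0.5  # warns, then forwards
```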
@@ -290,67 +335,77 @@ class OmDetTurboProcessor(ProcessorMixin):
         Args:
             outputs ([`OmDetTurboObjectDetectionOutput`]):
                 Raw outputs of the model.
-            classes (Union[List[str], List[List[str]]]): The input classes names.
-            score_threshold (float, defaults to 0.3): Only return detections with a confidence score exceeding this
-                threshold.
-            nms_threshold (float, defaults to 0.5): The threshold to use for box non-maximum suppression. Value in [0, 1].
-            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to None):
+            text_labels (Union[List[str], List[List[str]]], *optional*):
+                The input classes names. If not provided, `text_labels` will be set to `None` in `outputs`.
+            threshold (float, defaults to 0.3):
+                Only return detections with a confidence score exceeding this threshold.
+            nms_threshold (float, defaults to 0.5):
+                The threshold to use for box non-maximum suppression. Value in [0, 1].
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                 Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                 `(height, width)` of each image in the batch. If unset, predictions will not be resized.
-            max_num_det (int, *optional*, defaults to None): The maximum number of detections to return.
+            max_num_det (`int`, *optional*):
+                The maximum number of detections to return.
         Returns:
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image
             in the batch as predicted by the model.
         """
-        if isinstance(classes[0], str):
-            classes = [classes]
-
-        boxes_logits = outputs.decoder_coord_logits
-        scores_logits = outputs.decoder_class_logits
+        batch_size = len(outputs.decoder_coord_logits)

-        # Inputs consistency check
+        # Inputs consistency check for target sizes
         if target_sizes is None:
-            height = (
-                self.image_processor.size["height"]
-                if "height" in self.image_processor.size
-                else self.image_processor.size["shortest_edge"]
-            )
-            width = (
-                self.image_processor.size["width"]
-                if "width" in self.image_processor.size
-                else self.image_processor.size["longest_edge"]
-            )
-            target_sizes = ((height, width),) * len(boxes_logits)
-        elif len(target_sizes[0]) != 2:
+            height, width = self._get_default_image_size()
+            target_sizes = [(height, width)] * batch_size
+
+        if any(len(image_size) != 2 for image_size in target_sizes):
             raise ValueError(
                 "Each element of target_sizes must contain the size (height, width) of each image of the batch"
             )
-        if len(target_sizes) != len(boxes_logits):
+
+        if len(target_sizes) != batch_size:
             raise ValueError("Make sure that you pass in as many target sizes as output sequences")
-        if len(classes) != len(boxes_logits):
+
+        # Inputs consistency check for text labels
+        if text_labels is not None and isinstance(text_labels[0], str):
+            text_labels = [text_labels]
+
+        if text_labels is not None and len(text_labels) != batch_size:
             raise ValueError("Make sure that you pass in as many classes group as output sequences")

-        scores, predicted_classes = compute_score(scores_logits)
-        num_classes = scores_logits.shape[2]
+        # Convert target_sizes to list for easier handling
+        if isinstance(target_sizes, torch.Tensor):
+            target_sizes = target_sizes.tolist()
+
+        batch_boxes = outputs.decoder_coord_logits
+        batch_logits = outputs.decoder_class_logits
+        batch_num_classes = outputs.classes_structure
+
+        batch_scores, batch_labels = compute_score(batch_logits)
+
         results = []
-        for scores_img, box_per_img, image_size, class_names in zip(scores, boxes_logits, target_sizes, classes):
-            results.append(
-                _post_process_boxes_for_image(
-                    box_per_img,
-                    scores_img,
-                    predicted_classes,
-                    class_names,
-                    image_size,
-                    num_classes,
-                    score_threshold=score_threshold,
-                    nms_threshold=nms_threshold,
-                    max_num_det=max_num_det,
-                )
-            )
+        for boxes, scores, image_size, image_num_classes in zip(
+            batch_boxes, batch_scores, target_sizes, batch_num_classes
+        ):
+            boxes, scores, labels = _post_process_boxes_for_image(
+                boxes=boxes,
+                scores=scores,
+                labels=batch_labels,
+                image_num_classes=image_num_classes,
+                image_size=image_size,
+                threshold=threshold,
+                nms_threshold=nms_threshold,
+                max_num_det=max_num_det,
+            )
+            result = DictWithDeprecationWarning(
+                {"boxes": boxes, "scores": scores, "labels": labels, "text_labels": None}
+            )
+            results.append(result)
+
+        # Add text labels
+        if text_labels is not None:
+            for result, image_text_labels in zip(results, text_labels):
+                result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]]

         return results
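The final label-to-string mapping is worth seeing in isolation; a tiny sketch of the step above with made-up ids:

```python
import torch

# After _post_process_boxes_for_image, each result carries integer label ids
result = {"labels": torch.tensor([1, 0, 1]), "text_labels": None}
image_text_labels = ["cat", "remote"]

# The final step maps ids back to the caller's strings, exactly as above
result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]]
print(result["text_labels"])  # ['remote', 'cat', 'remote']
```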
tests/models/omdet_turbo/test_modeling_omdet_turbo.py

@@ -646,9 +646,9 @@ def prepare_img():


 def prepare_text():
-    classes = ["cat", "remote"]
-    task = "Detect {}.".format(", ".join(classes))
-    return classes, task
+    text_labels = ["cat", "remote"]
+    task = "Detect {}.".format(", ".join(text_labels))
+    return text_labels, task


 def prepare_img_batched():
@@ -660,14 +660,14 @@ def prepare_img_batched():


 def prepare_text_batched():
-    classes1 = ["cat", "remote"]
-    classes2 = ["boat"]
-    classes3 = ["statue", "trees", "torch"]
+    text_labels1 = ["cat", "remote"]
+    text_labels2 = ["boat"]
+    text_labels3 = ["statue", "trees", "torch"]

-    task1 = "Detect {}.".format(", ".join(classes1))
+    task1 = "Detect {}.".format(", ".join(text_labels1))
     task2 = "Detect all the boat in the image."
     task3 = "Focus on the foreground, detect statue, torch and trees."
-    return [classes1, classes2, classes3], [task1, task2, task3]
+    return [text_labels1, text_labels2, text_labels3], [task1, task2, task3]


 @require_timm
@@ -683,8 +683,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         processor = self.default_processor
         image = prepare_img()
-        classes, task = prepare_text()
-        encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device)
+        text_labels, task = prepare_text()
+        encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to(torch_device)

         with torch.no_grad():
             outputs = model(**encoding)
@@ -706,7 +706,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         # verify grounded postprocessing
         results = processor.post_process_grounded_object_detection(
-            outputs, classes=[classes], target_sizes=[image.size[::-1]]
+            outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
         )[0]
         expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
         expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
@@ -715,8 +715,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
         self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))

-        expected_classes = ["remote", "cat", "remote", "cat"]
-        self.assertListEqual(results["classes"], expected_classes)
+        expected_text_labels = ["remote", "cat", "remote", "cat"]
+        self.assertListEqual(results["text_labels"], expected_text_labels)

     def test_inference_object_detection_head_fp16(self):
         model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(
@@ -725,8 +725,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         processor = self.default_processor
         image = prepare_img()
-        classes, task = prepare_text()
-        encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(
+        text_labels, task = prepare_text()
+        encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to(
             torch_device, dtype=torch.float16
         )
@@ -750,7 +750,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         # verify grounded postprocessing
         results = processor.post_process_grounded_object_detection(
-            outputs, classes=[classes], target_sizes=[image.size[::-1]]
+            outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
         )[0]
         expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16)
         expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(
@@ -761,16 +761,16 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
         self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1))

-        expected_classes = ["remote", "cat", "remote", "cat"]
-        self.assertListEqual(results["classes"], expected_classes)
+        expected_text_labels = ["remote", "cat", "remote", "cat"]
+        self.assertListEqual(results["text_labels"], expected_text_labels)

     def test_inference_object_detection_head_no_task(self):
         model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device)

         processor = self.default_processor
         image = prepare_img()
-        classes, _ = prepare_text()
-        encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device)
+        text_labels, _ = prepare_text()
+        encoding = processor(images=image, text=text_labels, return_tensors="pt").to(torch_device)

         with torch.no_grad():
             outputs = model(**encoding)
@@ -792,7 +792,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         # verify grounded postprocessing
         results = processor.post_process_grounded_object_detection(
-            outputs, classes=[classes], target_sizes=[image.size[::-1]]
+            outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
         )[0]
         expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device)
         expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device)
@@ -801,8 +801,8 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2))
         self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2))

-        expected_classes = ["remote", "cat", "remote", "cat"]
-        self.assertListEqual(results["classes"], expected_classes)
+        expected_text_labels = ["remote", "cat", "remote", "cat"]
+        self.assertListEqual(results["text_labels"], expected_text_labels)

     def test_inference_object_detection_head_batched(self):
         torch_device = "cpu"
@@ -810,10 +810,10 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         processor = self.default_processor
         images_batched = prepare_img_batched()
-        classes_batched, tasks_batched = prepare_text_batched()
-        encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to(
-            torch_device
-        )
+        text_labels_batched, tasks_batched = prepare_text_batched()
+        encoding = processor(
+            images=images_batched, text=text_labels_batched, task=tasks_batched, return_tensors="pt"
+        ).to(torch_device)

         with torch.no_grad():
             outputs = model(**encoding)
@@ -837,7 +837,7 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
         # verify grounded postprocessing
         results = processor.post_process_grounded_object_detection(
             outputs,
-            classes=classes_batched,
+            text_labels=text_labels_batched,
             target_sizes=[image.size[::-1] for image in images_batched],
             score_threshold=0.2,
         )
@@ -858,19 +858,19 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):
             torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2)
         )

-        expected_classes = [
+        expected_text_labels = [
             ["remote", "cat", "remote", "cat"],
             ["boat", "boat", "boat", "boat"],
             ["statue", "trees", "trees", "torch", "statue", "statue"],
         ]
-        self.assertListEqual([result["classes"] for result in results], expected_classes)
+        self.assertListEqual([result["text_labels"] for result in results], expected_text_labels)

     @require_torch_accelerator
     def test_inference_object_detection_head_equivalence_cpu_gpu(self):
         processor = self.default_processor
         image = prepare_img()
-        classes, task = prepare_text()
-        encoding = processor(images=image, text=classes, task=task, return_tensors="pt")
+        text_labels, task = prepare_text()
+        encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt")
         # 1. run model on CPU
         model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
@@ -894,10 +894,10 @@ class OmDetTurboModelIntegrationTests(unittest.TestCase):

         # verify grounded postprocessing
         results_cpu = processor.post_process_grounded_object_detection(
-            cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
+            cpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
         )[0]
         result_gpu = processor.post_process_grounded_object_detection(
-            gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]]
+            gpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]]
         )[0]

         self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2))
tests/models/omdet_turbo/test_processor_omdet_turbo.py

@@ -76,10 +76,13 @@ class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         shutil.rmtree(self.tmpdirname)

     def get_fake_omdet_turbo_output(self):
+        classes = self.get_fake_omdet_turbo_classes()
+        classes_structure = torch.tensor([len(sublist) for sublist in classes])
         torch.manual_seed(42)
         return OmDetTurboObjectDetectionOutput(
             decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4),
             decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim),
+            classes_structure=classes_structure,
         )

     def get_fake_omdet_turbo_classes(self):
@@ -99,7 +102,7 @@ class OmDetTurboProcessorTest(ProcessorTesterMixin, unittest.TestCase):
         )

         self.assertEqual(len(post_processed), self.batch_size)
-        self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"])
+        self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "labels", "text_labels"])
         self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4))
         self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,))
         expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252])