From c81ebd1c39e3e1bc017a3affbba096dc9aedb5a0 Mon Sep 17 00:00:00 2001 From: Alara Dirik <8944735+alaradirik@users.noreply.github.com> Date: Tue, 20 Sep 2022 10:41:56 +0300 Subject: [PATCH] Beit postprocessing (#19099) * add post_process_semantic_segmentation method to BeiTFeatureExtractor --- docs/source/en/model_doc/beit.mdx | 1 + .../models/beit/feature_extraction_beit.py | 48 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/beit.mdx b/docs/source/en/model_doc/beit.mdx index f8177443d..689eadc70 100644 --- a/docs/source/en/model_doc/beit.mdx +++ b/docs/source/en/model_doc/beit.mdx @@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code [[autodoc]] BeitFeatureExtractor - __call__ + - post_process_semantic_segmentation ## BeitModel diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 62b790621..eac1ba8e3 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -14,7 +14,7 @@ # limitations under the License. """Feature extractor class for BEiT.""" -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np from PIL import Image @@ -27,9 +27,12 @@ from ...image_utils import ( ImageInput, is_torch_tensor, ) -from ...utils import TensorType, logging +from ...utils import TensorType, is_torch_available, logging +if is_torch_available(): + import torch + logger = logging.get_logger(__name__) @@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs + + def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): + """ + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`BeitForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to + None, predictions will not be resized. + Returns: + semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length + `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if + `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + logits = outputs.logits + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + semantic_segmentation = logits.argmax(dim=1) + + # Resize semantic segmentation maps + if target_sizes is not None: + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + resized_maps = [] + semantic_segmentation = semantic_segmentation.numpy() + + for idx in range(len(semantic_segmentation)): + resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) + resized_maps.append(resized) + + semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps] + + return semantic_segmentation