From 491e9518754c357c448be775e89a4c4644554635 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Thu, 15 Dec 2022 18:47:04 +0000 Subject: [PATCH] Move convert_to_rgb to image_transforms module (#20784) * Move convert_to_rgb to image_transforms module * Fix tests --- src/transformers/image_transforms.py | 20 +++++++++++++ .../models/bit/image_processing_bit.py | 17 ++--------- .../image_processing_chinese_clip.py | 17 ++--------- .../models/clip/image_processing_clip.py | 17 ++--------- .../vit_hybrid/image_processing_vit_hybrid.py | 18 ++--------- tests/test_image_transforms.py | 30 +++++++++++++++++++ 6 files changed, 58 insertions(+), 61 deletions(-) diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index 0129107a8..d09f29b79 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -20,6 +20,7 @@ import numpy as np from transformers.image_utils import ( ChannelDimension, + ImageInput, get_channel_dimension_axis, get_image_size, infer_channel_dimension_format, @@ -687,3 +688,22 @@ def pad( image = to_channel_dimension_format(image, data_format) if data_format is not None else image return image + + +# TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default +def convert_to_rgb(image: ImageInput) -> ImageInput: + """ + Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image + as is. + + Args: + image (Image): + The image to convert. + """ + requires_backends(convert_to_rgb, ["vision"]) + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("RGB") + return image diff --git a/src/transformers/models/bit/image_processing_bit.py b/src/transformers/models/bit/image_processing_bit.py index 8f2fa9524..f210ad30d 100644 --- a/src/transformers/models/bit/image_processing_bit.py +++ b/src/transformers/models/bit/image_processing_bit.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for BiT.""" -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( center_crop, + convert_to_rgb, get_resize_output_image_size, normalize, rescale, @@ -41,20 +42,6 @@ if is_vision_available(): import PIL -def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]: - """ - Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is. - - Args: - image (`PIL.Image.Image`): - The image to convert. - """ - if not isinstance(image, PIL.Image.Image): - return image - - return image.convert("RGB") - - class BitImageProcessor(BaseImageProcessor): r""" Constructs a BiT image processor. diff --git a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py index a38c79054..593ba05f8 100644 --- a/src/transformers/models/chinese_clip/image_processing_chinese_clip.py +++ b/src/transformers/models/chinese_clip/image_processing_chinese_clip.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for Chinese-CLIP.""" -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( center_crop, + convert_to_rgb, get_resize_output_image_size, normalize, rescale, @@ -41,20 +42,6 @@ if is_vision_available(): import PIL -def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]: - """ - Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is. - - Args: - image (`PIL.Image.Image`): - The image to convert. - """ - if not isinstance(image, PIL.Image.Image): - return image - - return image.convert("RGB") - - class ChineseCLIPImageProcessor(BaseImageProcessor): r""" Constructs a Chinese-CLIP image processor. diff --git a/src/transformers/models/clip/image_processing_clip.py b/src/transformers/models/clip/image_processing_clip.py index a30d1cad4..2f660a622 100644 --- a/src/transformers/models/clip/image_processing_clip.py +++ b/src/transformers/models/clip/image_processing_clip.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for CLIP.""" -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( center_crop, + convert_to_rgb, get_resize_output_image_size, normalize, rescale, @@ -41,20 +42,6 @@ if is_vision_available(): import PIL -def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]: - """ - Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is. - - Args: - image (`PIL.Image.Image`): - The image to convert. - """ - if not isinstance(image, PIL.Image.Image): - return image - - return image.convert("RGB") - - class CLIPImageProcessor(BaseImageProcessor): r""" Constructs a CLIP image processor. diff --git a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py index 296346a54..2cd007470 100644 --- a/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/image_processing_vit_hybrid.py @@ -14,7 +14,7 @@ # limitations under the License. """Image processor class for ViT hybrid.""" -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -23,6 +23,7 @@ from transformers.utils.generic import TensorType from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import ( center_crop, + convert_to_rgb, get_resize_output_image_size, normalize, rescale, @@ -41,21 +42,6 @@ if is_vision_available(): import PIL -# Copied from transformers.models.bit.image_processing_bit.convert_to_rgb -def convert_to_rgb(image: Union[Any, PIL.Image.Image]) -> Union[Any, PIL.Image.Image]: - """ - Converts `PIL.Image.Image` to RGB format. Images in other formats are returned as is. - - Args: - image (`PIL.Image.Image`): - The image to convert. - """ - if not isinstance(image, PIL.Image.Image): - return image - - return image.convert("RGB") - - class ViTHybridImageProcessor(BaseImageProcessor): r""" Constructs a ViT Hybrid image processor. diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index f22657f63..206c8dc5b 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -37,6 +37,7 @@ if is_vision_available(): from transformers.image_transforms import ( center_crop, center_to_corners_format, + convert_to_rgb, corners_to_center_format, get_resize_output_image_size, id_to_rgb, @@ -456,3 +457,32 @@ class ImageTransformsTester(unittest.TestCase): self.assertTrue( np.allclose(expected_image, pad(image, ((0, 2), (2, 1)), mode="reflect", data_format="channels_last")) ) + + @require_vision + def test_convert_to_rgb(self): + # Test that an RGBA image is converted to RGB + image = np.array([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=np.uint8) + pil_image = PIL.Image.fromarray(image) + self.assertEqual(pil_image.mode, "RGBA") + self.assertEqual(pil_image.size, (2, 1)) + + # For the moment, numpy images are returned as is + rgb_image = convert_to_rgb(image) + self.assertEqual(rgb_image.shape, (1, 2, 4)) + self.assertTrue(np.allclose(rgb_image, image)) + + # And PIL images are converted + rgb_image = convert_to_rgb(pil_image) + self.assertEqual(rgb_image.mode, "RGB") + self.assertEqual(rgb_image.size, (2, 1)) + self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[1, 2, 3], [5, 6, 7]]], dtype=np.uint8))) + + # Test that a grayscale image is converted to RGB + image = np.array([[0, 255]], dtype=np.uint8) + pil_image = PIL.Image.fromarray(image) + self.assertEqual(pil_image.mode, "L") + self.assertEqual(pil_image.size, (2, 1)) + rgb_image = convert_to_rgb(pil_image) + self.assertEqual(rgb_image.mode, "RGB") + self.assertEqual(rgb_image.size, (2, 1)) + self.assertTrue(np.allclose(np.array(rgb_image), np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)))