diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py index e9cb93db6..b3a25a8be 100644 --- a/src/transformers/image_transforms.py +++ b/src/transformers/image_transforms.py @@ -376,6 +376,11 @@ def normalize( channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format) num_channels = image.shape[channel_axis] + # We cast to float32 to avoid errors that can occur when subtracting uint8 values. + # We preserve the original dtype if it is a float type to prevent upcasting float16. + if not np.issubdtype(image.dtype, np.floating): + image = image.astype(np.float32) + if isinstance(mean, Iterable): if len(mean) != num_channels: raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 2941685e6..ae86f84de 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -302,7 +302,7 @@ class ImageTransformsTester(unittest.TestCase): normalized_image = normalize(image, mean=mean, std=std, data_format="channels_first") self.assertIsInstance(normalized_image, np.ndarray) self.assertEqual(normalized_image.shape, (3, 224, 224)) - self.assertTrue(np.allclose(normalized_image, expected_image)) + self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6)) # Test image with 4 channels is normalized correctly image = np.random.randint(0, 256, (224, 224, 4)) / 255 @@ -310,9 +310,42 @@ class ImageTransformsTester(unittest.TestCase): std = (0.1, 0.2, 0.3, 0.4) expected_image = (image - mean) / std self.assertTrue( - np.allclose(normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image) + np.allclose( + normalize(image, mean=mean, std=std, input_data_format="channels_last"), expected_image, atol=1e-6 + ) ) + # Test float32 image input keeps float32 dtype + image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32) / 255 + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + expected_image = ((image - mean) / std).astype(np.float32) + normalized_image = normalize(image, mean=mean, std=std) + self.assertEqual(normalized_image.dtype, np.float32) + self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6)) + + # Test float16 image input keeps float16 dtype + image = np.random.randint(0, 256, (224, 224, 3)).astype(np.float16) / 255 + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + + # The mean and std are cast to match the dtype of the input image + cast_mean = np.array(mean, dtype=np.float16) + cast_std = np.array(std, dtype=np.float16) + expected_image = (image - cast_mean) / cast_std + normalized_image = normalize(image, mean=mean, std=std) + self.assertEqual(normalized_image.dtype, np.float16) + self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6)) + + # Test int image input is converted to float32 + image = np.random.randint(0, 2, (224, 224, 3), dtype=np.uint8) + mean = (0.5, 0.6, 0.7) + std = (0.1, 0.2, 0.3) + expected_image = (image.astype(np.float32) - mean) / std + normalized_image = normalize(image, mean=mean, std=std) + self.assertEqual(normalized_image.dtype, np.float32) + self.assertTrue(np.allclose(normalized_image, expected_image, atol=1e-6)) + def test_center_crop(self): image = np.random.randint(0, 256, (3, 224, 224))