From e2d6d5ce57373a453e8c6fb4293cf08c17b061e1 Mon Sep 17 00:00:00 2001 From: Michal Jamroz Date: Tue, 24 Oct 2023 14:32:03 +0200 Subject: [PATCH] Normalize only if needed (#26049) * Normalize only if needed * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * if else in one line * within block * one more place, sorry for mess * import order * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update examples/pytorch/image-classification/run_image_classification_no_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../image-classification/run_image_classification.py | 7 ++++++- .../run_image_classification_no_trainer.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) mode change 100644 => 100755 examples/pytorch/image-classification/run_image_classification.py diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py old mode 100644 new mode 100755 index 27a81f009..d00929754 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -28,6 +28,7 @@ from PIL import Image from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -325,7 +326,11 @@ def main(): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) _train_transforms = Compose( [ RandomResizedCrop(size), diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py index 3e38a3e79..52b5fabd8 100644 --- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py +++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py @@ -32,6 +32,7 @@ from torch.utils.data import DataLoader from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -331,7 +332,11 @@ def main(): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) train_transforms = Compose( [ RandomResizedCrop(size),