Normalize only if needed (#26049)

* Normalize only if needed

* Update examples/pytorch/image-classification/run_image_classification.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* if else in one line

* within block

* one more place, sorry for mess

* import order

* Update examples/pytorch/image-classification/run_image_classification.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update examples/pytorch/image-classification/run_image_classification_no_trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Michal Jamroz 2023-10-24 14:32:03 +02:00 committed by GitHub
parent 576e2823a3
commit e2d6d5ce57
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View file

@ -28,6 +28,7 @@ from PIL import Image
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
@ -325,7 +326,11 @@ def main():
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),

View file

@ -32,6 +32,7 @@ from torch.utils.data import DataLoader
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
@ -331,7 +332,11 @@ def main():
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
train_transforms = Compose(
[
RandomResizedCrop(size),