mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Normalize only if needed (#26049)
* Normalize only if needed * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * if else in one line * within block * one more place, sorry for mess * import order * Update examples/pytorch/image-classification/run_image_classification.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update examples/pytorch/image-classification/run_image_classification_no_trainer.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
576e2823a3
commit
e2d6d5ce57
2 changed files with 12 additions and 2 deletions
7
examples/pytorch/image-classification/run_image_classification.py
Normal file → Executable file
7
examples/pytorch/image-classification/run_image_classification.py
Normal file → Executable file
|
|
@ -28,6 +28,7 @@ from PIL import Image
|
|||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
Lambda,
|
||||
Normalize,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
|
|
@ -325,7 +326,11 @@ def main():
|
|||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
normalize = (
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
|
||||
else Lambda(lambda x: x)
|
||||
)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ from torch.utils.data import DataLoader
|
|||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
Lambda,
|
||||
Normalize,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
|
|
@ -331,7 +332,11 @@ def main():
|
|||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
normalize = (
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
|
||||
else Lambda(lambda x: x)
|
||||
)
|
||||
train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
|
|
|
|||
Loading…
Reference in a new issue