From e37ab52dff8ba167938beac8f42349bda3198198 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Fri, 22 Dec 2023 12:41:35 +0100 Subject: [PATCH] Bug: `training_args.py` fix missing import with accelerate with version `accelerate==0.20.1` (#28171) * fix-accelerate-version * updated with exported ACCELERATE_MIN_VERSION, * update string in ACCELERATE_MIN_VERSION --- src/transformers/training_args.py | 6 ++++-- src/transformers/utils/__init__.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index bc1f1bfd5..9acb9d78d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -36,6 +36,7 @@ from .trainer_utils import ( SchedulerType, ) from .utils import ( + ACCELERATE_MIN_VERSION, ExplicitEnum, cached_property, is_accelerate_available, @@ -1837,9 +1838,10 @@ class TrainingArguments: requires_backends(self, ["torch"]) logger.info("PyTorch: setting up devices") if not is_sagemaker_mp_enabled(): - if not is_accelerate_available(min_version="0.20.1"): + if not is_accelerate_available(): raise ImportError( - "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`" + f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " + "Please run `pip install transformers[torch]` or `pip install accelerate -U`" ) AcceleratorState._reset_state(reset_partial_state=True) self.distributed_state = None diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 02959b329..7364a1c67 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -90,6 +90,7 @@ from .hub import ( try_to_load_from_cache, ) from .import_utils import ( + ACCELERATE_MIN_VERSION, ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION,