Bug: training_args.py fix missing import with accelerate with version accelerate==0.20.1 (#28171)

* fix-accelerate-version

* updated with exported ACCELERATE_MIN_VERSION,

* update string in ACCELERATE_MIN_VERSION
This commit is contained in:
Michael Feil 2023-12-22 12:41:35 +01:00 committed by GitHub
parent c9fb250a25
commit e37ab52dff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View file

@ -36,6 +36,7 @@ from .trainer_utils import (
SchedulerType,
)
from .utils import (
ACCELERATE_MIN_VERSION,
ExplicitEnum,
cached_property,
is_accelerate_available,
@ -1837,9 +1838,10 @@ class TrainingArguments:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not is_sagemaker_mp_enabled():
if not is_accelerate_available(min_version="0.20.1"):
if not is_accelerate_available():
raise ImportError(
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
"Please run `pip install transformers[torch]` or `pip install accelerate -U`"
)
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None

View file

@ -90,6 +90,7 @@ from .hub import (
try_to_load_from_cache,
)
from .import_utils import (
ACCELERATE_MIN_VERSION,
ENV_VARS_TRUE_AND_AUTO_VALUES,
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,