mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Bug: training_args.py fix missing import with accelerate with version accelerate==0.20.1 (#28171)
* fix-accelerate-version * updated with exported ACCELERATE_MIN_VERSION, * update string in ACCELERATE_MIN_VERSION
This commit is contained in:
parent
c9fb250a25
commit
e37ab52dff
2 changed files with 5 additions and 2 deletions
|
|
@ -36,6 +36,7 @@ from .trainer_utils import (
|
|||
SchedulerType,
|
||||
)
|
||||
from .utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
ExplicitEnum,
|
||||
cached_property,
|
||||
is_accelerate_available,
|
||||
|
|
@ -1837,9 +1838,10 @@ class TrainingArguments:
|
|||
requires_backends(self, ["torch"])
|
||||
logger.info("PyTorch: setting up devices")
|
||||
if not is_sagemaker_mp_enabled():
|
||||
if not is_accelerate_available(min_version="0.20.1"):
|
||||
if not is_accelerate_available():
|
||||
raise ImportError(
|
||||
"Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
|
||||
f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
|
||||
"Please run `pip install transformers[torch]` or `pip install accelerate -U`"
|
||||
)
|
||||
AcceleratorState._reset_state(reset_partial_state=True)
|
||||
self.distributed_state = None
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ from .hub import (
|
|||
try_to_load_from_cache,
|
||||
)
|
||||
from .import_utils import (
|
||||
ACCELERATE_MIN_VERSION,
|
||||
ENV_VARS_TRUE_AND_AUTO_VALUES,
|
||||
ENV_VARS_TRUE_VALUES,
|
||||
TORCH_FX_REQUIRED_VERSION,
|
||||
|
|
|
|||
Loading…
Reference in a new issue