mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Including support for Deepspeed 0.8.0 (#14506)
### Description Including Support for Deepspeed 0.8.0. ### Motivation and Context Deepspeed 0.8.0 has a bug fix and mlfow integration.
This commit is contained in:
parent
d06ad9462b
commit
6fa4555a06
1 changed files with 14 additions and 1 deletions
|
|
@ -39,7 +39,7 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier):
|
|||
# it's safe to update the version supporting list. Otherwise, or the file is moved or renamed,
|
||||
# we need to check the implementation of these functions in detail.
|
||||
ds_version = Version(deepspeed.__version__)
|
||||
if ds_version > Version("0.7.3") or ds_version < Version("0.4.0"):
|
||||
if ds_version > Version("0.8.0") or ds_version < Version("0.4.0"):
|
||||
warnings.warn(
|
||||
"Skip modifying optimizer because of unsupported DeepSpeed version {}, "
|
||||
"supported version: 0.4.0 - 0.7.3.".format(deepspeed.__version__),
|
||||
|
|
@ -47,6 +47,19 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier):
|
|||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
from deepspeed.accelerator import get_accelerator
|
||||
except ImportError as e:
|
||||
warnings.warn("Unable to import get_accelerator from deepspeed.accelerator", UserWarning)
|
||||
else:
|
||||
if not get_accelerator().device_name().startswith("cuda"):
|
||||
warnings.warn(
|
||||
"Skip modifying optimizer as device is not supported, "
|
||||
"device name: {}".format(get_accelerator().device_name()),
|
||||
UserWarning,
|
||||
)
|
||||
return False
|
||||
|
||||
return self.check_requirements(
|
||||
["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"],
|
||||
require_apex=False,
|
||||
|
|
|
|||
Loading…
Reference in a new issue