From 6fa4555a06e8b5ec588ff1d7880ff42eb6ba2007 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Wed, 1 Feb 2023 06:19:41 -0800 Subject: [PATCH] Including support for Deepspeed 0.8.0 (#14506) ### Description Including Support for Deepspeed 0.8.0. ### Motivation and Context Deepspeed 0.8.0 has a bug fix and mlflow integration. --- .../python/training/optim/_ds_modifier.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 9e3470d33c..dc2d4c091c 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -39,7 +39,7 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier): # it's safe to update the version supporting list. Otherwise, or the file is moved or renamed, # we need to check the implementation of these functions in detail. ds_version = Version(deepspeed.__version__) - if ds_version > Version("0.7.3") or ds_version < Version("0.4.0"): + if ds_version > Version("0.8.0") or ds_version < Version("0.4.0"): warnings.warn( "Skip modifying optimizer because of unsupported DeepSpeed version {}, " "supported version: 0.4.0 - 0.7.3.".format(deepspeed.__version__), @@ -47,6 +47,19 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier): ) return False + try: + from deepspeed.accelerator import get_accelerator + except ImportError as e: + warnings.warn("Unable to import get_accelerator from deepspeed.accelerator", UserWarning) + else: + if not get_accelerator().device_name().startswith("cuda"): + warnings.warn( + "Skip modifying optimizer as device is not supported, " + "device name: {}".format(get_accelerator().device_name()), + UserWarning, + ) + return False + return self.check_requirements( ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"], require_apex=False,