diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index d274975ce1..6ae6ccee51 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -10,10 +10,11 @@ # - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799 # -------------------------------------------------------------------------- -import torch import types import warnings from distutils.version import LooseVersion + +import torch from numpy import inf from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads @@ -27,14 +28,11 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier): super().__init__(optimizer) def can_be_modified(self): - try: - import deepspeed + import deepspeed - v = LooseVersion(deepspeed.__version__) - if v > LooseVersion("0.5.9") or v < LooseVersion("0.4.0"): - warnings.warn("Unsupported DeepSpeed version to override, skipped.", UserWarning) - return False - except Exception as _: + ds_version = LooseVersion(deepspeed.__version__) + if ds_version > LooseVersion("0.6.5") or ds_version < LooseVersion("0.4.0"): + warnings.warn("Skip modifying optimizer because of unsupported DeepSpeed version.", UserWarning) return False return self.check_requirements( @@ -141,14 +139,8 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier): #### END OF THE ORIGINAL IMPLEMENTATION #### #### THIS IS THE FASTER IMPLEMENTATION #### - import deepspeed - - if LooseVersion(deepspeed.__version__) >= LooseVersion("0.5.7"): - fp16_groups = target.bit16_groups - else: - fp16_groups = target.fp16_groups - - for i in range(len(fp16_groups)): + groups = target.fp16_groups if hasattr(target, "fp16_groups") else target.bit16_groups + for i in range(len(groups)): grad_data = [grad.data for grad in target.averaged_gradients[i] if grad is not None] if check_overflow_for_grads(grad_data): return True diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py index 9897ed4121..b3ad73110d 100644 --- a/orttraining/orttraining/python/training/optim/_modifier.py +++ b/orttraining/orttraining/python/training/optim/_modifier.py @@ -9,6 +9,7 @@ # -------------------------------------------------------------------------- import torch +import warnings from numpy import inf from ._multi_tensor_apply import MultiTensorApply @@ -32,12 +33,16 @@ class FP16OptimizerModifier(object): if require_torch_non_finite_check is True: _ = torch._amp_foreach_non_finite_check_and_unscale_ except Exception as _: + warnings.warn("Skip modifying optimizer because of Apex or torch_non_finite_check not found.", UserWarning) return False if required_funcs: for func_name in required_funcs: func = getattr(self._optimizer, func_name, None) if not func or not callable(func): + warnings.warn( + "Skip modifying optimizer because of specific function not found in optimizer.", UserWarning + ) return False return True diff --git a/orttraining/orttraining/python/training/optim/_modifier_registry.py b/orttraining/orttraining/python/training/optim/_modifier_registry.py index b19ecd6b06..4291b792a4 100644 --- a/orttraining/orttraining/python/training/optim/_modifier_registry.py +++ b/orttraining/orttraining/python/training/optim/_modifier_registry.py @@ -7,14 +7,9 @@ from ._ds_modifier import DeepSpeedZeROModifier from ._megatron_modifier import LegacyMegatronLMModifier from ._apex_amp_modifier import ApexAMPModifier -LEAGCY_MEGATRON_LM_OPTIMIZER_NAME = "megatron.fp16.fp16.FP16_Optimizer" -DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME = "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer" -DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME_1 = "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer" -APEX_AMP_OPTIMIZER_NAME = "apex.amp.optimizer.unique_name_as_id" - OptimizerModifierTypeRegistry = { - LEAGCY_MEGATRON_LM_OPTIMIZER_NAME: LegacyMegatronLMModifier, - DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME: DeepSpeedZeROModifier, - DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME_1: DeepSpeedZeROModifier, - APEX_AMP_OPTIMIZER_NAME: ApexAMPModifier, + "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, + "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier, } diff --git a/orttraining/orttraining/python/training/optim/fp16_optimizer.py b/orttraining/orttraining/python/training/optim/fp16_optimizer.py index c4c353249f..c3864ea711 100644 --- a/orttraining/orttraining/python/training/optim/fp16_optimizer.py +++ b/orttraining/orttraining/python/training/optim/fp16_optimizer.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +import warnings + from ._modifier_registry import OptimizerModifierTypeRegistry @@ -90,6 +92,7 @@ def FP16_Optimizer(optimizer, **kwargs): optimizer_full_qualified_name = get_full_qualified_type_name(optimizer) if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry: + warnings.warn("Skip modifying optimizer because of optimizer name not found in registry.", UserWarning) return optimizer modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs)