mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
FP16_Optimizer Support for more Deepspeed Versions (#12046)
* fp16_optimizer for more ds versions
* change ds version
* bugfix
* fix bug
This commit is contained in:
parent
ecca6f4d16
commit
04f7c2deda
4 changed files with 20 additions and 25 deletions
|
|
@ -10,10 +10,11 @@
|
|||
# - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import types
|
||||
import warnings
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
from numpy import inf
|
||||
|
||||
from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads
|
||||
|
|
@ -27,14 +28,11 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier):
|
|||
super().__init__(optimizer)
|
||||
|
||||
def can_be_modified(self):
|
||||
try:
|
||||
import deepspeed
|
||||
import deepspeed
|
||||
|
||||
v = LooseVersion(deepspeed.__version__)
|
||||
if v > LooseVersion("0.5.9") or v < LooseVersion("0.4.0"):
|
||||
warnings.warn("Unsupported DeepSpeed version to override, skipped.", UserWarning)
|
||||
return False
|
||||
except Exception as _:
|
||||
ds_version = LooseVersion(deepspeed.__version__)
|
||||
if ds_version > LooseVersion("0.6.5") or ds_version < LooseVersion("0.4.0"):
|
||||
warnings.warn("Skip modifying optimizer because of unsupported DeepSpeed version.", UserWarning)
|
||||
return False
|
||||
|
||||
return self.check_requirements(
|
||||
|
|
@ -141,14 +139,8 @@ class DeepSpeedZeROModifier(FP16OptimizerModifier):
|
|||
#### END OF THE ORIGINAL IMPLEMENTATION ####
|
||||
|
||||
#### THIS IS THE FASTER IMPLEMENTATION ####
|
||||
import deepspeed
|
||||
|
||||
if LooseVersion(deepspeed.__version__) >= LooseVersion("0.5.7"):
|
||||
fp16_groups = target.bit16_groups
|
||||
else:
|
||||
fp16_groups = target.fp16_groups
|
||||
|
||||
for i in range(len(fp16_groups)):
|
||||
groups = target.fp16_groups if hasattr(target, "fp16_groups") else target.bit16_groups
|
||||
for i in range(len(groups)):
|
||||
grad_data = [grad.data for grad in target.averaged_gradients[i] if grad is not None]
|
||||
if check_overflow_for_grads(grad_data):
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
# --------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import warnings
|
||||
from numpy import inf
|
||||
from ._multi_tensor_apply import MultiTensorApply
|
||||
|
||||
|
|
@ -32,12 +33,16 @@ class FP16OptimizerModifier(object):
|
|||
if require_torch_non_finite_check is True:
|
||||
_ = torch._amp_foreach_non_finite_check_and_unscale_
|
||||
except Exception as _:
|
||||
warnings.warn("Skip modifying optimizer because of Apex or torch_non_finite_check not found.", UserWarning)
|
||||
return False
|
||||
|
||||
if required_funcs:
|
||||
for func_name in required_funcs:
|
||||
func = getattr(self._optimizer, func_name, None)
|
||||
if not func or not callable(func):
|
||||
warnings.warn(
|
||||
"Skip modifying optimizer because of specific function not found in optimizer.", UserWarning
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -7,14 +7,9 @@ from ._ds_modifier import DeepSpeedZeROModifier
|
|||
from ._megatron_modifier import LegacyMegatronLMModifier
|
||||
from ._apex_amp_modifier import ApexAMPModifier
|
||||
|
||||
LEAGCY_MEGATRON_LM_OPTIMIZER_NAME = "megatron.fp16.fp16.FP16_Optimizer"
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME = "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer"
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME_1 = "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer"
|
||||
APEX_AMP_OPTIMIZER_NAME = "apex.amp.optimizer.unique_name_as_id"
|
||||
|
||||
OptimizerModifierTypeRegistry = {
|
||||
LEAGCY_MEGATRON_LM_OPTIMIZER_NAME: LegacyMegatronLMModifier,
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME: DeepSpeedZeROModifier,
|
||||
DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME_1: DeepSpeedZeROModifier,
|
||||
APEX_AMP_OPTIMIZER_NAME: ApexAMPModifier,
|
||||
"megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier,
|
||||
"deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
|
||||
"deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier,
|
||||
"apex.amp.optimizer.unique_name_as_id": ApexAMPModifier,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import warnings
|
||||
|
||||
from ._modifier_registry import OptimizerModifierTypeRegistry
|
||||
|
||||
|
||||
|
|
@ -90,6 +92,7 @@ def FP16_Optimizer(optimizer, **kwargs):
|
|||
|
||||
optimizer_full_qualified_name = get_full_qualified_type_name(optimizer)
|
||||
if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry:
|
||||
warnings.warn("Skip modifying optimizer because of optimizer name not found in registry.", UserWarning)
|
||||
return optimizer
|
||||
|
||||
modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue