Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-16 21:00:14 +00:00
### FP16 optimizer: automatically detect DeepSpeed compatibility

Optimum/Transformers now use the accelerate library to prepare models, so our FP16 optimizer wrapper has not worked for a while: the prepared optimizer's type is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`, which under the hood still calls into the DeepSpeed stage 1 and 2 optimizer. This PR includes the following changes:

1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the modifier registry, plus a check that its contained `optimizer` property MUST be a DeepSpeed stage 1 or 2 optimizer (stage 3 will be covered later).
2. For DeepSpeed versions > 0.9.1, we store the source code of the patched functions in a version list. As long as the related functions in DeepSpeed remain unchanged in new releases, we no longer need to bump the version check manually. If the source code ever stops matching, a warning asks users to add a new source snapshot to the list.

With the above changes, our FP16 optimizer works in Optimum again.
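The version check in item 2 can be read as a source-snapshot comparison. The sketch below is only illustrative: the names `_KNOWN_SOURCES` and `check_overridden_function` are hypothetical and do not come from the onnxruntime code; it simply shows the idea of accepting any DeepSpeed release whose patched function still matches a known snapshot, and warning instead of failing hard when it does not.

```python
# Hypothetical sketch of the "store source code per version" check described above;
# names and structure are illustrative, not the actual onnxruntime implementation.
import inspect
import warnings

# Source snapshots of the DeepSpeed function being overridden, collected from
# releases (> 0.9.1) that were verified to work with the FP16 optimizer wrapper.
_KNOWN_SOURCES: list[str] = [
    # e.g. inspect.getsource(...) output captured from DeepSpeed 0.9.2, 0.10.0, ...
]


def check_overridden_function(func) -> bool:
    """Return True if func's current source matches a known snapshot, else warn."""
    current_source = inspect.getsource(func)
    if current_source in _KNOWN_SOURCES:
        # Unchanged since a verified release: the override can be applied without
        # manually bumping any version pin.
        return True
    warnings.warn(
        f"{func.__qualname__} changed in this DeepSpeed release; please add its "
        "source to the known-snapshot list before relying on the FP16 optimizer override."
    )
    return False
```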
91 lines · 3.8 KiB · Python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name


def FP16_Optimizer(optimizer, **kwargs):  # noqa: N802
    """
    Simple wrapper to replace inefficient FP16_Optimizer function calls implemented by libraries for example
    Apex, DeepSpeed, Megatron-LM.

    Usage:

    1. DeepSpeed ZeRO Optimizer Override:

        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)

        >>> model, optimizer, _, lr_scheduler = deepspeed.initialize(
        >>>     model=model,
        >>>     optimizer=optimizer,
        >>>     args=args,
        >>>     lr_scheduler=lr_scheduler,
        >>>     mpu=mpu,
        >>>     dist_init_required=False)
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer)

    2. Megatron-LM-v1.1.5 Optimizer Override:

        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)

        >>> # Wrap into fp16 optimizer.
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer,
        >>>                                static_loss_scale=args.loss_scale,
        >>>                                dynamic_loss_scale=args.dynamic_loss_scale,
        >>>                                dynamic_loss_args={
        >>>                                    'scale_window': args.loss_scale_window,
        >>>                                    'min_scale': args.min_scale,
        >>>                                    'delayed_shift': args.hysteresis},
        >>>                                verbose=True)
        >>>     optimizer = ORT_FP16_Optimizer(optimizer,
        >>>                                    get_tensor_model_parallel_rank=mpu.get_model_parallel_rank,
        >>>                                    get_tensor_model_parallel_group=mpu.get_model_parallel_group)

    3. APEX AMP Override:

        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        >>> model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        >>> optimizer = ORT_FP16_Optimizer(optimizer)
        >>>
        >>> # Wrap model with ORTModule tricks
        >>> def patch_new_fwd(old_new_fwd):
        >>>     def new_new_fwd(self, *args, **kwargs):
        >>>         return old_new_fwd(*args, **kwargs)
        >>>     return new_new_fwd

        >>> model.forward = types.MethodType(patch_new_fwd(model.forward), model)
        >>> model = ORTModule(model)

    Args:
        optimizer: the FP16_Optimizer instance

    Returns:
        The modified FP16_Optimizer instance

    """

    optimizer_full_qualified_name = (
        "apex.amp.optimizer.unique_name_as_id"
        if hasattr(optimizer, "_amp_stash")
        else get_full_qualified_type_name(optimizer)
    )
    modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs)
    if modifier is not None:
        modifier.apply()

    return optimizer
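The docstring above does not yet show the Accelerate path that this PR enables. Below is a minimal, hypothetical usage sketch for an Optimum/Accelerate-style setup; the model, learning rate, and DeepSpeed configuration are placeholders, and it assumes `accelerate` has been configured for DeepSpeed ZeRO stage 1 or 2 (e.g. via `accelerate config`).

```python
# Hypothetical usage sketch for the Accelerate/DeepSpeed path enabled by this PR.
import torch
from accelerate import Accelerator
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer

accelerator = Accelerator()  # assumes a DeepSpeed ZeRO stage 1/2 plugin is configured
model = torch.nn.Linear(128, 128)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# With a DeepSpeed plugin active, prepare() returns the optimizer wrapped in
# accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper.
model, optimizer = accelerator.prepare(model, optimizer)

# The registry entry added by this PR recognizes that wrapper (when it holds a
# DeepSpeed stage 1/2 optimizer) and applies the efficient ORT override on top.
optimizer = ORT_FP16_Optimizer(optimizer)
```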