onnxruntime/orttraining/orttraining/python/training/optim/fp16_optimizer.py
pengwa 2c6b31c5aa
FP16 optimizer automatically detect DeepSpeed compatibility (#18084)

Optimum/Transformers use the accelerate library to prepare models, so our
FP16 optimizer wrapper has not worked for a long time: the wrapped
optimizer's type is `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`,
which under the hood still calls into the DeepSpeed stage 1 and 2 optimizer.

This PR includes the following changes:
1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the
modifier registry, plus a check that its contained `optimizer` property
MUST be a DeepSpeed stage 1 or 2 optimizer. (Let's cover the stage 3
optimizer later.)
2. For DeepSpeed versions > 0.9.1, we store the source code of the
patched functions in a version list. As long as the related functions in
DeepSpeed remain unchanged in new releases, we no longer need to
manually bump the version check. If the source code ever stops matching,
a warning is raised asking users to add the new version of the source
code to the list.
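The wrapper check in point 1 can be sketched roughly as follows. This is a minimal illustration only: `is_supported_wrapper`, the accepted-class-name tuple, and the dummy classes are hypothetical stand-ins, not the actual ORT registry code, which inspects the real accelerate and DeepSpeed classes.

```python
# Sketch of the "contained optimizer MUST be DeepSpeed stage 1/2" check.
# The accepted class name below is an assumption about DeepSpeed's naming.
STAGE_1_AND_2_OPTIMIZER_NAMES = ("DeepSpeedZeroOptimizer",)


def is_supported_wrapper(wrapped_optimizer):
    """Accept a wrapper only when its `optimizer` attribute looks like a
    DeepSpeed stage 1/2 optimizer (stage 3 is deliberately not covered)."""
    inner = getattr(wrapped_optimizer, "optimizer", None)
    return inner is not None and type(inner).__name__ in STAGE_1_AND_2_OPTIMIZER_NAMES


# Dummy stand-ins for accelerate's DeepSpeedOptimizerWrapper and the
# DeepSpeed stage 1/2 optimizer it contains.
class DeepSpeedZeroOptimizer:
    pass


class FakeWrapper:
    def __init__(self, inner):
        self.optimizer = inner


assert is_supported_wrapper(FakeWrapper(DeepSpeedZeroOptimizer()))
assert not is_supported_wrapper(FakeWrapper(object()))  # unsupported inner optimizer
```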

With the above changes, our FP16 Optimizer works in Optimum again.
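The version-tolerant check from point 2 can be sketched as below. This is a hedged illustration, not the exact ORT implementation: the helper name, the digest-based comparison, and the sample source strings are assumptions (the real code compares stored source text for the patched DeepSpeed functions, obtained via `inspect.getsource`).

```python
import hashlib
import warnings


def source_matches_known_versions(func_name, current_source, known_digests):
    """Return True when the function's current source text matches one of the
    recorded releases; otherwise warn, as the PR describes, so a new version
    can be added to the list."""
    digest = hashlib.sha256(current_source.encode("utf-8")).hexdigest()
    if digest in known_digests:
        return True
    warnings.warn(f"Source of {func_name} changed; please record the new version.")
    return False


# Pretend these are the bodies of one DeepSpeed function across two releases.
V091_SOURCE = "def step(self):\n    return self._step_v1()\n"
NEW_SOURCE = "def step(self):\n    return self._step_v2()\n"

known = {hashlib.sha256(V091_SOURCE.encode("utf-8")).hexdigest()}
assert source_matches_known_versions("step", V091_SOURCE, known)      # unchanged release: OK
assert not source_matches_known_versions("step", NEW_SOURCE, known)   # changed: warns
```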


![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b)
2023-10-25 15:11:02 +08:00


# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name


def FP16_Optimizer(optimizer, **kwargs):  # noqa: N802
    """
    Simple wrapper to replace inefficient FP16_Optimizer function calls implemented by libraries such as
    Apex, DeepSpeed, and Megatron-LM.

    Usage:

    1. DeepSpeed ZeRO Optimizer Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)
        >>> model, optimizer, _, lr_scheduler = deepspeed.initialize(
        >>>     model=model,
        >>>     optimizer=optimizer,
        >>>     args=args,
        >>>     lr_scheduler=lr_scheduler,
        >>>     mpu=mpu,
        >>>     dist_init_required=False)
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer)

    2. Megatron-LM-v1.1.5 Optimizer Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)
        >>> # Wrap into fp16 optimizer.
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer,
        >>>                                static_loss_scale=args.loss_scale,
        >>>                                dynamic_loss_scale=args.dynamic_loss_scale,
        >>>                                dynamic_loss_args={
        >>>                                    'scale_window': args.loss_scale_window,
        >>>                                    'min_scale': args.min_scale,
        >>>                                    'delayed_shift': args.hysteresis},
        >>>                                verbose=True)
        >>> optimizer = ORT_FP16_Optimizer(optimizer,
        >>>                                get_tensor_model_parallel_rank=mpu.get_model_parallel_rank,
        >>>                                get_tensor_model_parallel_group=mpu.get_model_parallel_group)

    3. APEX AMP Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        >>> model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        >>> optimizer = ORT_FP16_Optimizer(optimizer)
        >>>
        >>> # Wrap model with ORTModule tricks
        >>> def patch_new_fwd(old_new_fwd):
        >>>     def new_new_fwd(self, *args, **kwargs):
        >>>         return old_new_fwd(*args, **kwargs)
        >>>     return new_new_fwd
        >>> model.forward = types.MethodType(patch_new_fwd(model.forward), model)
        >>> model = ORTModule(model)

    Args:
        optimizer: the FP16_Optimizer instance

    Returns:
        The modified FP16_Optimizer instance
    """
    optimizer_full_qualified_name = (
        "apex.amp.optimizer.unique_name_as_id"
        if hasattr(optimizer, "_amp_stash")
        else get_full_qualified_type_name(optimizer)
    )
    modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs)
    if modifier is not None:
        modifier.apply()
    return optimizer
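The dispatch key passed to `create_modifier` is a fully qualified type name (with a fixed sentinel string for Apex, since AMP patches the optimizer instance in place rather than wrapping it in a distinctive class). A minimal sketch of how such a key can be derived — the helper below is an illustrative re-implementation, not the actual `get_full_qualified_type_name` from `_modifier_registry`:

```python
def full_qualified_type_name(obj):
    """Build the "<module>.<qualname>" string for obj's class: the kind of
    key the modifier registry dispatches on (illustrative re-implementation)."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


class DummyOptimizer:  # hypothetical optimizer class, for demonstration only
    pass


assert full_qualified_type_name(DummyOptimizer()).endswith(".DummyOptimizer")
assert full_qualified_type_name(1) == "builtins.int"
```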