onnxruntime/orttraining/orttraining/python/training/optim
pengwa 2c6b31c5aa
FP16 optimizer automatically detect DeepSpeed compatibility (#18084)
### FP16 optimizer automatically detects DeepSpeed compatibility

Optimum/Transformers use the accelerate library to prepare models, so our
FP16 optimizer wrapper has not worked with them for a long time. The
optimizer they hand us has the type
`accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper`, which our modifier
registry did not recognize, even though underneath it still calls into the
DeepSpeed stage 1 and 2 optimizer.
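The mismatch can be sketched as follows, assuming the registry is keyed on an optimizer's fully qualified type name. The two classes are hypothetical stand-ins that only mimic the (assumed) module paths of the real accelerate and DeepSpeed classes, so the sketch runs without either library. `full_type_name`, `SUPPORTED`, `WRAPPERS` and `resolve_supported_name` are illustrative names, not ONNX Runtime's actual API.

```python
# Hypothetical stand-ins: only their module paths mimic the real classes.
class DeepSpeedZeroOptimizer:
    __module__ = "deepspeed.runtime.zero.stage_1_and_2"


class DeepSpeedOptimizerWrapper:
    __module__ = "accelerate.utils.deepspeed"

    def __init__(self, optimizer):
        # The wrapper keeps the real DeepSpeed optimizer in `.optimizer`
        self.optimizer = optimizer


def full_type_name(obj) -> str:
    """Return '<module>.<qualname>' for the object's class."""
    cls = type(obj)
    return f"{cls.__module__}.{cls.__qualname__}"


# Hypothetical registry contents
SUPPORTED = {"deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer"}
WRAPPERS = {"accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper"}


def resolve_supported_name(optimizer):
    """Return the supported type name, looking through known wrappers.

    Returns None when neither the optimizer nor its wrapped `optimizer`
    is a supported type (for example a Stage 3 optimizer).
    """
    name = full_type_name(optimizer)
    if name in SUPPORTED:
        return name
    if name in WRAPPERS:
        inner = full_type_name(optimizer.optimizer)
        if inner in SUPPORTED:
            return inner
    return None


wrapped = DeepSpeedOptimizerWrapper(DeepSpeedZeroOptimizer())
print(full_type_name(wrapped))          # accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper
print(resolve_supported_name(wrapped))  # deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer
```

A lookup on the wrapper's own name alone would miss, which is why the check has to inspect the wrapped `optimizer` as well.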

This PR includes the following changes:
1. Add `accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper` to the
modifier registry, plus a check that its contained `optimizer` property
MUST be a DeepSpeed stage 1 or 2 optimizer (Stage 3 optimizer support
will be covered later).
2. For DeepSpeed versions > 0.9.1, we store the source code of the related
DeepSpeed function in a version list. As long as that function remains
unchanged in a new DeepSpeed release, we no longer need to manually
upgrade the version check. If the source code ever stops matching, a
warning is raised asking users to add the new version of the source code
to the list.
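One way to implement the source-code check in change 2 is a plain string comparison against every stored version. In this minimal sketch the current source is passed in as a string so it runs without DeepSpeed installed; in practice it could come from `inspect.getsource` on the DeepSpeed function being overridden. `KNOWN_SOURCES`, `hypothetical_upstream_fn` and `source_is_known` are hypothetical names, and the sketch leaves out the `> 0.9.1` version gate.

```python
import textwrap
import warnings

# Hypothetical store of upstream source versions the override was validated
# against. Each entry is the full source of one known-good implementation.
KNOWN_SOURCES = [
    """
    def hypothetical_upstream_fn(x):
        return x + 1
    """,
]


def _normalize(src: str) -> str:
    # Drop shared indentation (e.g. from a method body) and surrounding blank lines
    return textwrap.dedent(src).strip()


def source_is_known(current_source: str, known_sources=KNOWN_SOURCES) -> bool:
    """Return True if current_source matches any stored version; warn otherwise."""
    current = _normalize(current_source)
    if any(current == _normalize(known) for known in known_sources):
        return True
    warnings.warn(
        "Upstream source changed; add the new version to the code store "
        "after verifying the override is still valid.",
        UserWarning,
    )
    return False


print(source_is_known("def hypothetical_upstream_fn(x):\n    return x + 1"))  # True
```

Exact comparison is deliberately strict: any upstream edit, even a harmless one, triggers the warning, which errs on the side of never silently overriding code that has changed.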

With the above changes, our FP16 optimizer works again in Optimum.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/d35b4aa9-b371-46f1-98ae-73114f91179b)
2023-10-25 15:11:02 +08:00
| File | Last commit | Date |
| --- | --- | --- |
| `__init__.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_apex_amp_modifier.py` | Enable pylint and numpy rules (#15218) | 2023-03-27 20:37:53 -07:00 |
| `_ds_code_store.py` | FP16 optimizer automatically detect DeepSpeed compatibility (#18084) | 2023-10-25 15:11:02 +08:00 |
| `_ds_modifier.py` | FP16 optimizer automatically detect DeepSpeed compatibility (#18084) | 2023-10-25 15:11:02 +08:00 |
| `_megatron_modifier.py` | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| `_modifier.py` | Replace call to deprecated torch.norm (#16758) | 2023-07-20 19:52:19 -07:00 |
| `_modifier_registry.py` | FP16 optimizer automatically detect DeepSpeed compatibility (#18084) | 2023-10-25 15:11:02 +08:00 |
| `_multi_tensor_apply.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `config.py` | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| `fp16_optimizer.py` | FP16 optimizer automatically detect DeepSpeed compatibility (#18084) | 2023-10-25 15:11:02 +08:00 |
| `fused_adam.py` | Adding this set_to_none flag to zero_grad to have signature parity with pytorch Adam (#16375) | 2023-06-19 17:27:41 -07:00 |
| `lr_scheduler.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |