onnxruntime/orttraining/orttraining/python/training/optim/fp16_optimizer.py
Justin Chu d834ec895a
Adopt lintrunner as the linting tool - take 2 (#15085)
### Description

`lintrunner` is a linter runner successfully used by pytorch, onnx, and
onnx-script. It provides a uniform experience for running linters locally
and in CI, and it supports all major dev systems: Windows, Linux, and macOS.
The checks are enforced by the `Python format` workflow.

This PR adopts `lintrunner` for onnxruntime and fixes ~2000 flake8 errors
in the Python code. `lintrunner` now runs all required Python lints,
including `ruff` (replacing `flake8`), `black`, and `isort`. Future lints
like `clang-format` can be added.

Most errors were auto-fixed by `ruff`, and the fixes should be considered
robust.

Lints that are more complicated to fix are suppressed with `# noqa` for now
and should be fixed in follow-up PRs.
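
For illustration, the suppression looks like this (the `N802` line mirrors the file shown in this commit; the `F401` line is a hypothetical example, and both rule codes are standard `ruff` codes):

```python
# Suppress a single rule on one line; the code after the colon names the rule.
import torch  # noqa: F401  (unused import kept for its import-time side effects)


def FP16_Optimizer(optimizer, **kwargs):  # noqa: N802  (uppercase name kept for API compatibility)
    ...
```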

### Notable changes

1. This PR **removed some suboptimal patterns** (a before/after sketch follows this list):

	- `not xxx in` -> `xxx not in` membership checks
	- bare excepts (`except:` -> `except Exception`)
	- unused imports
	
	The follow up PR will remove:
	
	- `import *`
	- mutable values as default in function definitions (`def func(a=[])`)
	- more unused imports
	- unused local variables

2. Use `ruff` to replace `flake8`. `ruff` is much (~40x) faster than
`flake8` and more robust. We are using it successfully in onnx and
onnx-script, and it also supports auto-fixing many flake8 errors.

3. Removed the legacy flake8 CI flow and updated the docs.

4. The added workflow supports SARIF code-scanning reports on GitHub. Example snapshot:

![image](https://user-images.githubusercontent.com/11205048/212598953-d60ce8a9-f242-4fa8-8674-8696b704604a.png)

5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant.
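
For reference, here is the before/after sketch mentioned in item 1. The names (`table`, `append_item`) are illustrative, not taken from the codebase; `E713`, `E722`, and `B006` are the standard `ruff`/`flake8` codes for these patterns:

```python
table = {"a": 1}

# Before: patterns flagged by the linter.
if not "b" in table:  # E713: negated membership test
    table["b"] = 0

try:
    1 / 0
except:  # E722: bare except also swallows SystemExit and KeyboardInterrupt
    pass

def append_item(item, bucket=[]):  # B006: one shared default list across all calls
    bucket.append(item)
    return bucket

# After: the auto-fixed / follow-up equivalents.
if "b" not in table:
    table["b"] = 0

try:
    1 / 0
except Exception:
    pass

def append_item(item, bucket=None):
    bucket = [] if bucket is None else bucket
    bucket.append(item)
    return bucket
```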

### Motivation and Context

A unified linting experience in CI and locally.

Replaces https://github.com/microsoft/onnxruntime/pull/14306.

---------

Signed-off-by: Justin Chu <justinchu@microsoft.com>
2023-03-24 15:29:03 -07:00


# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import warnings

from ._modifier_registry import OptimizerModifierTypeRegistry


def FP16_Optimizer(optimizer, **kwargs):  # noqa: N802
    """
    Simple wrapper to replace inefficient FP16_Optimizer function calls implemented by libraries, for example
    Apex, DeepSpeed, and Megatron-LM.

    Usage:

    1. DeepSpeed ZeRO Optimizer Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)
        >>> model, optimizer, _, lr_scheduler = deepspeed.initialize(
        >>>     model=model,
        >>>     optimizer=optimizer,
        >>>     args=args,
        >>>     lr_scheduler=lr_scheduler,
        >>>     mpu=mpu,
        >>>     dist_init_required=False)
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer)

    2. Megatron-LM-v1.1.5 Optimizer Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = Adam(param_groups,
        >>>                  lr=args.lr,
        >>>                  weight_decay=args.weight_decay,
        >>>                  betas=(args.adam_beta1, args.adam_beta2),
        >>>                  eps=args.adam_eps)
        >>> # Wrap into fp16 optimizer.
        >>> if args.fp16:
        >>>     optimizer = FP16_Optimizer(optimizer,
        >>>                                static_loss_scale=args.loss_scale,
        >>>                                dynamic_loss_scale=args.dynamic_loss_scale,
        >>>                                dynamic_loss_args={
        >>>                                    'scale_window': args.loss_scale_window,
        >>>                                    'min_scale': args.min_scale,
        >>>                                    'delayed_shift': args.hysteresis},
        >>>                                verbose=True)
        >>> optimizer = ORT_FP16_Optimizer(optimizer,
        >>>                                get_tensor_model_parallel_rank=mpu.get_model_parallel_rank,
        >>>                                get_tensor_model_parallel_group=mpu.get_model_parallel_group)

    3. APEX AMP Override:
        >>> from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        >>> model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        >>> optimizer = ORT_FP16_Optimizer(optimizer)
        >>>
        >>> # Wrap model with ORTModule tricks
        >>> def patch_new_fwd(old_new_fwd):
        >>>     def new_new_fwd(self, *args, **kwargs):
        >>>         return old_new_fwd(*args, **kwargs)
        >>>     return new_new_fwd
        >>> model.forward = types.MethodType(patch_new_fwd(model.forward), model)
        >>> model = ORTModule(model)

    Args:
        optimizer: the FP16_Optimizer instance

    Returns:
        The modified FP16_Optimizer instance
    """

    def get_full_qualified_type_name(o):
        # Apex AMP patches the optimizer instance in place, so its class name is
        # not a reliable registry key; detect it via the `_amp_stash` attribute.
        if hasattr(o, "_amp_stash"):
            return "apex.amp.optimizer.unique_name_as_id"

        klass = o.__class__
        module = klass.__module__
        if module == "builtins":
            return klass.__qualname__
        return module + "." + klass.__qualname__

    # Dispatch on the optimizer's fully qualified type name; if no modifier is
    # registered for it, warn and return the optimizer unmodified.
    optimizer_full_qualified_name = get_full_qualified_type_name(optimizer)
    if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry:
        warnings.warn("Skip modifying optimizer because the optimizer name was not found in the registry.", UserWarning)
        return optimizer

    modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs)
    modifier.apply()
    return optimizer
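
To make the dispatch above concrete, here is a minimal usage sketch of the registry-miss path, assuming `torch` and the `onnxruntime-training` package are installed. `torch.optim.SGD` is chosen as an optimizer that presumably has no registered modifier, so the function warns and returns it unchanged:

```python
import warnings

import torch

from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# "torch.optim.sgd.SGD" has no entry in OptimizerModifierTypeRegistry, so
# FP16_Optimizer emits a UserWarning and returns the optimizer as-is.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    wrapped = FP16_Optimizer(optimizer)

assert wrapped is optimizer
assert any(issubclass(w.category, UserWarning) for w in caught)
```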