onnxruntime/tools/ci_build
pengwa 1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

I previously tried to solve the coarse-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, but that
solution was not fully validated with DDP, or with user hooks
registered on gradient accumulation for torch parameters.

This PR addresses the problem with an approach similar to PR 8979,
i.e. it triggers gradient accumulation as soon as ORT has computed the
grad. Instead of using an AccumulateGrad op, this time it uses the ONNX
operator PythonOp, which internally calls param.backward(grad) so that
all related hooks are handled correctly.
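
As a rough illustration of why calling `param.backward(grad)` is enough to respect user hooks, the sketch below uses plain PyTorch only. The names `accumulate_into_param`, `user_hook`, and `ort_grad` are hypothetical stand-ins and are not taken from the ORTModule code; the real PythonOp integration is not shown here.

```python
import torch

# Hypothetical stand-in for a model parameter managed by ORTModule.
param = torch.nn.Parameter(torch.zeros(3))

# Gradients observed by a user-registered hook (e.g. what DDP or a user might attach).
seen = []


def user_hook(grad):
    # Record the incoming gradient; returning None leaves it unchanged.
    seen.append(grad.clone())


handle = param.register_hook(user_hook)


def accumulate_into_param(p, grad):
    """Assumed shape of the call made inside the PythonOp: hand the
    externally computed gradient to autograd so hooks and accumulation run."""
    p.backward(grad)


# Pretend ORT produced this gradient for `param` in two micro-steps.
ort_grad = torch.tensor([1.0, 2.0, 3.0])
accumulate_into_param(param, ort_grad)
accumulate_into_param(param, ort_grad)

print(param.grad)  # gradients from both calls are summed into param.grad
print(len(seen))   # the hook fired once per call

handle.remove()
```

Because the gradient goes through autograd's own accumulation path, `param.grad` is summed across calls and every hook registered on the parameter fires, which is the behaviour the description above relies on.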


## Design

For details, see:

https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

The differences are mostly on the order of 0.000x, sometimes 0.00x,
which may come from the different order in which gradients are applied
before and after this change (on DeepSpeed ZeRO stage 2).


## TODO

Consolidate the logic with Stage3's similar logic.
2024-01-16 08:57:37 +08:00
..
github ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
__init__.py
amd_hipify.py undo hipify of __half to rocblas_half (#18573) 2023-11-24 18:04:23 +08:00
build.py Set default flags nvcc and do not set default compile flags for ROCM EP (#19124) 2024-01-14 11:36:49 -08:00
clean_docker_image_cache.py
compile_triton.py [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
coverage.py
gen_def.py [TensorRT EP] Refactor OrtTensorRTProviderOptions initialization and make it easy to add new field (#17617) 2023-10-06 14:12:20 -07:00
get_docker_image.py [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
logger.py
op_registration_utils.py [CI] Removes type2 in process_registration and fix Windows GPU Reduced Ops CI Pipeline (#16530) 2023-07-07 18:21:06 +02:00
op_registration_validator.py [CI] Removes type2 in process_registration and fix Windows GPU Reduced Ops CI Pipeline (#16530) 2023-07-07 18:21:06 +02:00
patch_manylinux.py [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
policheck_exclusions.xml
reduce_op_kernels.py Re-organize the transpose optimization and layout transformation files. (#16246) 2023-07-07 08:24:47 +10:00
replace_urls_in_deps.py Add a build validation for Linux ARM64 cross-compile (#18200) 2023-11-08 13:03:18 -08:00
requirements.txt Adding python3.12 support to ORT (#18814) 2024-01-11 08:34:28 -08:00
set-trigger-rules.py Pr trggiers generated by code (#17247) 2023-08-30 05:57:03 +08:00
update_tsaoptions.py
upload_python_package_to_azure_storage.py [Linter] Bump ruff and remove pylint (#17797) 2023-10-05 21:07:33 -07:00