mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
24 commits
| Author | SHA1 | Message | Date | |
|---|---|---|---|---|
|
|
4932e04053
|
ORTModule GraphTransitionManager (#19007)
### Problem
Currently, the codebase contains some logics pertaining to model
re-export checks and graph_builder reinitialization checks. Ideally,
these operations should function akin to a state machine. However, upon
inspecting the implementation, it becomes apparent that certain states
are checked or set in various scattered locations. This fragmentation
makes it challenging to comprehend when a re-export or re-initialization
will be triggered. For optimal clarity and maintainability, it is
advisable to consolidate these states into a cohesive component, rather
than dispersing them within the current graph execution manager.
Furthermore, the process of model exports and post-export processing for
stage 3 support or memory-efficient gradient management introduces
considerable complexity. To enhance the codebase's structure, it would
be beneficial to extract these intricate functionalities into a
dedicated component, divorcing them from the current graph execution
manager.
As part of the effort to improve the codebase, it's essential to address
inconsistencies in handling input/output flatten/unflatten operations.
Currently, there are several functions performing these operations
recursively, each with slightly different implementations. This
inconsistency leads to varying support for input/output data types and
structures in different parts of the code. To rectify this, the proposed
pull request simplifies these operations into a set of primitive
functions, ensuring uniformity. This not only streamlines the code but
also facilitates the maintenance of consistency when introducing bug
fixes or supporting new data types. One thing to mention here: input
output handling is deeply bound to the graph transition mentioned above,
so it is difficult to make this change separately.
While acknowledging the complexity of these logics, it is reassuring
that the codebase benefits from an extensive suite of unit tests that
cover all possible branches. Despite the intricacies, ensuring the
passage of all tests has been a time-intensive but necessary aspect of
this development effort.
### Design
Introduce `GraphTransitionManager` and put all model export and
post-export processing logics in it.
1. Re-export check
2. Do export
3. Re-post-export process check
4. Do post-export process
5. Return `PostExportProcessedModelInfo`, which contains all the
information we need, to pass to ORT to build gradient graph (currently
we do the same for training or evaluating, but ideally we should not do
it for evaluating, let's keep this behavior as it is now, and make the
change later).
```
# Input names for the pre-gradient-build graph.
# This may be different with the one in ExportedGraph since we may
modify the graph inputs as needed
# for example when memory efficient gradient management is enabled.
self.onnx_graph_input_names: list[str] = onnx_graph_input_names
# A subset of onnx_graph_input_names.
# Input names that require gradients for the pre-gradient-build graph.
self.onnx_graph_input_names_require_grad: list[str] =
onnx_graph_input_names_require_grad
# Create symbolic names for each dimension of the graph input (e.g.
onnx_graph_input_names).
# The key is the input name, the value is a dict of {dim_index:
symbolic_dim_name}
# e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0:
"input2_dim0"}}
self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] =
onnx_graph_input_dynamic_axes_map
self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()
self.onnx_graph_input_names_user_defined = (
onnx_graph_input_names_user_defined # The ONNX graph input names
excluding the parameters, buffers.
)
# The ONNX graph input names excluding the parameters, buffers.
self.onnx_graph_input_names_require_grad_user_defined =
onnx_graph_input_names_require_grad_user_defined
self._post_export_processed_model: onnx.ModelProto | None =
post_export_processed_model
# A function to access the input data from the args and kwargs.
# If it is not None, the length is same as onnx_graph_input_names.
# For i-th input name, we can use the i-th function to get the input
data from args and kwargs.
self.data_accessor: list[callable] | None = data_accessor
# Used for unflattening the outputs from the ORT forward run.
self.module_forward_output_schema: ORTModelInputOutputSchemaType | None
= module_forward_output_schema```
The `GraphTransitionManager` instance is a property of
`GraphExecutionManager` (e.g. `TrainingManager` or ``InferenceManager),
1. Use
'self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs,
kwargs)' to check whether the PyTorch module need a re-export or
re-post-export-process.
2. Use
`self._graph_transition_manager._post_export_processed_model_info.construct_inputs`
to construct the list of inputs used for ORT runs.
3. Use
`self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)`
to restore the outputs in original PyTorch output structure.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
|
||
|
|
3e954da3e6
|
Fix and enable few ORTModule Unit Tests (#19847)
### Fix and enable few ORTModule Unit Tests
Fix 'test_bert_inputs_with_dynamic_shape' and
'test_bert_result_with_layerwise_recompute' generate Nan loss in ORT
run.
The root cause is, the logic to generatic attention mask test data is
not correct, only 0 or 1 is allowed in the dataset, but we see lots of
other numbers. ( The reason we don't have this using old version of
transformers for example v4.4.2 or 4.16.2 is because they don't contains
such
|
||
|
|
026e3178ae
|
Improve memory matrix for ORTModule (#19620)
### Memory matrix for ORTModule Collect parameter/gradient/buffers sizes also. Exposed as a function, can be used externally for debugging purpose. ``` 2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▏ | 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▏ | 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 [2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1 0%|▎ | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▍ | 5/3200 [01:28<6:53:04, 7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▍ | 6/3200 [01:28<4:36:19, 5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 7/3200 [01:28<3:09:33, 3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 8/3200 [01:28<2:13:48, 2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▋ | 9/3200 [01:29<1:36:03, 1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 ``` |
||
|
|
1150b1f81e
|
ORTModule memory improvement (#18924)
## Dependency https://github.com/microsoft/onnxruntime/pull/19007 ## ORTModule memory efficient gradient management Previously I have tried to solve the coarsed-grained gradient accumulation/update problem in ORTModule with https://github.com/microsoft/onnxruntime/pull/8979, while that resolution somehow is not fully validated with DDP or there is user hooks on the gradient accumulation on torch parameter. This PR is addressing the problem in the similar approach as PR 8979, e.g. trigger gradient accumulation once ORT computed the grad, but instead of use a AccumulateGrad op, this time with a ONNX operator PythonOp, internally it will call param.backward(grad), which will help handle all related hooks correctly. ## Design Check the details from https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ ## Convergence Validation:  differences are on mostly 0.000x, sometimes 0.00x, which may comes from the different order gradient apply happens before or after this change (on deepspeed zero stage 2) ## TODO Consolidate the logic with Stage3's similar logic. |
||
|
|
998517b209
|
Minor fixes (#18949)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
5eda79bdd3
|
Improve perf for stage3 training (#18099)
### Improve perf for stage3 training - first wave Port existing PythonOp/PythonOpGrad python runner to C++, also introduce an unsafe run mode (to skip inplace, save for backward, materrialized grad detection on the fly). This reduce the overhead from XX~XXX us to X ~ lower end of XX us . In LLAMA2 7B training with 8x32GV100, we have observed 6.7% gains over PyTorch. (1.59 v.s. 1.49it/s) Peak memory also dropped from 31GB to 28GB. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
ccf3b2054b
|
Allow layer-wise recompute (#18566)
### Allow layer-wise recompute Early, we need users/developers to specify the subgraphs to recompute, now we introduced a more user-friendly way to enable recompute for all detected stashed activation recomputation subgraphs. This scarifies getting the best configs while makes it easier to support user requirements when they switches from PyTorch per-layer gradient checkpoint to ORTModule. `ORTMODULE_MEMORY_OPT_LEVEL` is introduced to control the usage, by default, it is 0, e.g. `USER_SPECIFIED`, all subgraphs definedin `ORTMODULE_MEMORY_OPT_CONFIG` will be recomputed. So this is compatible to existing recompute usage in ORTModule integrated models. Using `ORTMODULE_MEMORY_OPT_LEVEL=1`, we will enable all recompute plans detected, so those configs in `ORTMODULE_MEMORY_OPT_CONFIG` will not be respected any more. Add Unit Tests using 3 layer blooms. https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md |
||
|
|
43a5147e01
|
Memory optimization refactor and refinement (#17481)
### Memory optimization refactor and refinement Currently memory optimizer runs graph transformations and print recompute opportunities in INFO level, while ORT backend has many many INFO level logs making users hard to find those information. So we are looking for a Python binding API to retrieve the memory optimization opportunities instead of depending on the MemoryOptimizer's default logging. Then we can print ORTModule feature statistics using this information. Also, with such an API, we can create an ORT session created, where allocation plan is done, the analysis will consider buffer reuse as well. This can void giving some recomputation subgraphs that are reusing other subgraphs' output buffers. Check https://github.com/microsoft/onnxruntime/blob/pengwa/add_devinfo_level/docs/Memory_Optimizer.md for the new flow using `MemoryOptimizer`. This pull requests made following refactoring: 1. Print the log in ORTModule Python script, along with ORTModule feature enabling stats. This is implemented by exposing an API `get_serialized_ortmodule_memory_stat` to retrieve the memory optimization opportunities. 2. We are analyzing memory optimization opportunities considering ORT memory planning. This is done by firstly creating the execution graph without enabling MemoryOptimizer, then we call `execution_agent.get_serialized_ortmodule_memory_stat` which internally will consider the session memory allocation planner when analyzing memory optimization opportunity. As a direct result, the memory optimization opportunities can show those stashed activations that are reusing other buffers. 3. Move recompute analysis logic from memory_optimizer.h/cc to recompute_analysis.h/cc. 4. Abstract optimization strategies for their own implementation. This will make introducing new strategies (for example compression and decompression ) easier. New logging matrix (INFO Level), in WARNING level, the details will NOT show. ``` 2023-09-13 13:25:09,249 orttraining.rank-0 [WARNING] - ***** ONNX Runtime Training (ORTModule) is accelerating your model ***** ORTModule is enabled with following features ON/OFF for [training] mode: ATen Executor : ON : Dispatch ATen operators to ORT's ATen executor Cast Propagation : ON : Level 1 enabled Custom Function : ON : Support custom torch.autograd.Function export and execution Memory Optimizer : ON : RecomputeConfig: Reshape+Where+BiasSoftmax+:1:-1,Cast+:1:-1, ProbeLevel: 1, available configs: Config Freq Saving(B) Saving Symbolic(Bytes) - Plan 1 : ON : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 2 : ON : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0) - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2 - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1) - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0) - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1 - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 Compute Optimizer : ON : Enable/Disable with env ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1/0 - FLOPReduction : ON : Reduce FLOPs by upstreaming shrinking-sized ops Auto Fallback : ON : Fallback to PyTorch when encountering unsupported ops TritonOp Enabled : OFF : ORT will switch to Triton for executing some ops to further accelerate training. ZeRO Stage3 Support : OFF : Enable/Disable with env ORTMODULE_ENABLE_ZERO_STAGE3=1/0 Total ORT initialization overhead is 10.73s where export takes 8.39s. Other overhead details: graph builder init takes 0.06s, runtime detection takes 0.01s, graph building takes 0.31s, session creation takes 1.96s Versions: ONNX Runtime - 1.16.0+cu118, ONNX - 1.11.0 Note 1: use comma to enable multiple plans at the same time. export ORTMODULE_MEMORY_OPT_CONFIG=<plan1 config>,<plan2 config>,... Note 2: saving is calculated based on the 1st batch symbolic dim values: inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024, ************************************************************************ ``` If DEVINFO level is enabled, then more details about the memory optimizations are printed. ``` MemoryInsight Summary - User config: BiasGelu+:1:-1,Cast+:2:-1 ========================================================================================================================================== |Freq | Memory Optimization Opportunities (Clustered by node-level activation patterns) | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |3 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(3), | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 32 x 240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+:1:-1 | | | Stashed Activations: | | | - ReuseFreq : Output 0(2), | | | - Output 0 : [ x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=2 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |2 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+Add+FusedMatMul+Add+Add+Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x inputs_input_ids_dim1 x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Where+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Where+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph FusedMatMul+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=FusedMatMul+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | | | | | |>>Option 2 : RecomputeWithCompromise subgraph Cast+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 1 x 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 50% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasSoftmax+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=BiasSoftmax+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0 x 32 x inputs_input_ids_dim1 - 1 x inputs_input_ids_dim1 x ], byte/elem: 4, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph BiasGelu+ | | | Status : Enabled, requested count=-1, actual applied count=1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 10240 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |1 |For each row options are mutually exclusive, only one of them can be enabled. | | | | | |>>Option 1 : Recompute subgraph Add+ | | | Status : Disabled. Enable with export ORTMODULE_MEMORY_OPT_CONFIG=Add+:1:-1 | | | Stashed Activations: | | | - Output 0 : [inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1) x 2560 x ], byte/elem: 2, 100% saved | |_ _ _ _|_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | ========================================================================================================================================== Note: use comma as a separator for enabling more than one subgraphs. ************************************************************************ ``` ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
c8e1038eab
|
Optimize 4bit Qlora training (#18131)
### Optimize 4bit Qlora training Extent existing `MatmulBnb4bit` to its usage in training scenarios. The PR includes following changes: 1. Add special `torch.autograd.Function` export logic for `bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before common PythonOp exporter. 2. Add `training_mode` optional attribute for op `MatmulBnb4bit`, which help skip some inference specific logic in implementation. 3. Add `transB` optional attribute, which is by default be 1; setting it to be 0 is needed by backward usage. Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly ~2.9% throughput gains. The reason is: `bitsandbytes.autograd._functions.MatMul4Bit` has logic `ctx.save_for_backward`, which would need an additional copy in PythonOp, otherwise, the tensor might be released by ORT, while backward op still references it. Removing the clones also reduce the peak memory consumptions because `bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not needed in backward compute. |
||
|
|
0e2782438a
|
Support inplace update for PythonOp/Grad (#17687)
### Support inplace update for PythonOp/Grad This PR is based on another PR https://github.com/microsoft/onnxruntime/pull/17685's branch, to make it easier to review. With PR: PR https://github.com/microsoft/onnxruntime/pull/17685, By default all PythonOp inputs/outputs are assumed to not be inplaced, if during run, we found some inplace update happens (by checking output data address with all inputs data address), we add clone before set it as PythonOp/Grad's outputs. In this case, results are correct, but implicit copies overheads are introduced. This PR allow users to define output input reuse map, to let ORT know how to do the reuse map, avoid such unnecessary copies. |
||
|
|
54b7503c30
|
create patch for allgather fn for deepspeed stage 3 (#17855)
### Description <!-- Describe your changes. --> Patch for All gather fn for Deepspeed Stage 3 changes ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
6b7bce5ec9
|
Model post process for zero stage3 training (#17187)
### Model post process for zero stage3 training This is the last change to make single GPU/Multiple GPUs run pass. Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9 `PyTorch` runs with ZeROOffloadSubscriber: ``` model = prepare_model(...) from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 configure_ort_compatible_zero_stage3() ``` `ORTModule` runs with ZeROOffloadSubscriber: ``` os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1' from onnxruntime.training.ortmodule import ORTModule model = ORTModule(self.model) ``` It will be fairly easy to debug convergence issue if both ORT and PyTorch can run the same offload path. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
d90afc697b
|
Introduce ZeROOffloadSubscriber for ORTModule (#17006)
### Introduce ZeROOffloadSubscriber for ORTModule As part of the work: integrate ORTModule with DeepSpeed stage3, this PR mainly focus on moving original PyTorch-based (leveraging hooks) param partition/offload implementation to ORTModule compatible implementation. Changes include: 1. Refactor `SubscriberBase`/`SubcriberManager` to support pre-forward/post_forward hooks. 2. Implement new `ZeROOffloadSubscriber` by re-using DeepSpeed hook function as much as possible. Since all hook functions are defined in `DeepSpeedZeRoOffload._register_hooks_recursively` and `DeepSpeedZeRoOffload.setup_zero_stage3_hooks`, and the good thing is, the closure is not complex, all hooks are referencing the owning `DeepSpeedZeRoOffload` instance, so we can create new hook function with `FunctionType` by binding the owning `DeepSpeedZeRoOffload` instance, then call the new created function in subscriber's `pre_forward_module_apply_impl` and `post_forward_module_apply_impl` interfaces. 3. Monkey patch `DeepSpeedZeRoOffload.setup_zero_stage3_hooks` to register the `ZeROOffloadSubscriber` for the model, then we don't need change any code on the DeepSpeed repo (at least so far). 4. Fix the ATen embedding custom symbolic exporter function by tolerating weights size be (0) (changed by DeepSpeed zero stage 3). UT will be added once stage3 is fully supported. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
cd7b3f54da
|
Allow defining customized PythonOp shape inferer (#17093)
### Allow defining customized PythonOp shape inferer For `torch.autograd.Function`, we converted it to PythonOp in MSDomain, there are two places to do shape inferencing for it: 1. in SymbolicShapeInfer, there is one. 2. in PythonOp op definition. For common PythonOp, since we don't know the relation ship between inputs and outputs, so we only infer the rank from output ranks, and generate symbolic dimensions for each dim. While this will introduce many meaningless symbolic dimensions, sometimes blocking our graph transformers to do op fusion. This PR provide a way to define custom shape inferencing for `torch.autograd.Function` we defined, to propagate the original dimensions across the PythonOp at the best efforts. But the 2rd one is not covered yet, we could refine that later. Fixing 1st one is enough for ORTModule training/evaluation. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> |
||
|
|
6e6f582e08
|
Use full qualified name for PythonOp export (#17021)
### Use full qualified name for PythonOp export Originally, when there are duplicate named torch.autograd.Function in different module, for example: `a.b.c.Gelu` v.s. `d.e.func.<locals>.Gelu` We by default will throw exception to let user be aware we cannot distinguish the two Gelu because during model export, we did not module path. The workaround is we introduced `ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS` to ignore those duplicated named Gelu that is not used by model run. This has limitations obviously for example if two Gelus are both used in training. This PR finds a way to construct a full qualified name. `def _export_pt_1_10(g, n, *args, **kwargs):` 1. in exporter function, kwargs contains `name` and `module`, in the above example: `a.b.c.Gelu` --> name: `Gelu`, module: `a.b.c` `d.e.func.<locals>.Gelu` --> name: `Gelu`, module: `d.e` Using name and module is not enough to get a full qualified name, for the second case, where `d.e` is the module path, then there is a function called `func`, in this function, there is a local auto.grad.Function named `Gelu`. (Many of our UT looks like this). We can only get `d.e.Gelu`, but this is not the correct full qual name. The reason for this: `kwargs[name]` or `n.name` only return the class's name, not the class's full qual name. (be noted kwargs[module]` is correct). 2. `n` is torch.Node, we can access `pyobj` to get the torch.autograd.Function's apply method instance, then use `._self` to get the torch.autograd.Function class. Then we can get the `module` and `class`'s ful qual name, added together, we get the full qual name. With the above change, we don't need use `kwargs[name]` and `kwargs[module]` , and don't need check naming conflicting or `ORTMODULE_SKIPPED_AUTOGRAD_FUNCTIONS` env var any more. |
||
|
|
a6887f171f
|
Refactor schema extraction and output unflattening (#16894)
### Motivation and Context When we handle PyTorch models' inputs in different places (ORTModule or others), it's common for us to flatten a structured data into a 1-D tensor list (required by lib for example torch.onnx.export, torch.autograd.Function.forward or ORT inference session), then do subsequent work, then unflatten back to original hierarchy as returned values. DeepStage3 hooks support work also need such a lib to do similar things, so I was proposing to extract this pair of APIs in training/utils/, which can be more used more generally. Also a comprehensive set of test data are used for testing unflatten/flatten in unit tests. Let me know if you have any other suggestions. ### Refactor schema extraction and output unflattening Move `_extract_schema` and `unflatten_user_output` in `orttraining/orttraining/python/training/ortmodule/_io.py` . to `extract_data_and_schema` and `unflatten_data_using_schema` in `orttraining/orttraining/python/training/utils/torch_io_helper.py` as shared libs, which can be used later by other features (deepspeed stage 3 hook rewrite). While there are still a few duplicated logic handling flatten with different task by recursively loop the data struct, will change them step by step in case of heavy review efforts. |
||
|
|
735a32fee1
|
Introduce memory observer for ORTModule (#16213)
### Introduce memory observer for ORTModule To analyze memory usage for ORTModule training, we need collect per-iteration memory footprint in different stages (pre-forward, post-forward, pre-backward, and post-backward). Currently we only collect the data using torch.cuda APIs. The next step is, we could collect the detailed stashed activation list and its percentage within ORT backend, which is beyond this PR. Sample as below: ``` 0/8] step 0 memory (MiB) | phase: pre_forward | allocated: 1866 | max allocated: 1866 | cached: 1874 | max cached: 1874 | inactive: 8 | max inactive: 8 [0/8] step 0 memory (MiB) | phase: post_forward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405 [0/8] step 0 memory (MiB) | phase: pre_backward | allocated: 23277 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 193 | max inactive: 405 [0/8] step 0 memory (MiB) | phase: post_backward | allocated: 2932 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 6158 | max inactive: 6158 0%|█ | 1/200 [00:26<1:26:18, 26.02s/it] [0/8] step 1 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26215 | cached: 26406 | max cached: 26406 | inactive: 2454 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 1 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165 1%|██ | 2/200 [00:26<36:47, 11.15s/it] [0/8] step 2 memory (MiB) | phase: pre_forward | allocated: 2356 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2454 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: post_forward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: pre_backward | allocated: 23767 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 2639 | max inactive: 6165 [0/8] step 2 memory (MiB) | phase: post_backward | allocated: 3422 | max allocated: 26705 | cached: 29342 | max cached: 29342 | inactive: 5284 | max inactive: 6165 ``` |
||
|
|
40bcc0441b
|
Enhance StatisticsSubscriber (#16098)
### Enhance StatisticsSubscriber There are few improvements for `StatisticsSubscriber`: - Reduce peak memory impact for tensors (having many many many elements, consuming too much GPU memory, causing original recipe run failed with OOM), by split the statistics into two phases (split into buckets, and merge result across buckets). - Allow dump intermediate tensors. Originally only nn.Module forward()'s return value are dumped, there are requirements we want to inspect some specific intermediate tensor in the forward() function, now we support it. - Add documents for collecting dumps on multiple ranks Docs link on this branch for better view: https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool_v2/docs/ORTModule_Convergence_Notes.md --------- Co-authored-by: mindest <30493312+mindest@users.noreply.github.com> |
||
|
|
a36caba073
|
Bump ruff in CI (#15533)
### Description Bump ruff version in CI and fixed new lint errors. - This change enables the flake8-implicit-str-concat rules which helps detect unintended string concatenations: https://beta.ruff.rs/docs/rules/#flake8-implicit-str-concat-isc - Update gitignore to include common python files that we want to exclude. ### Motivation and Context Code quality |
||
|
|
938e2136c6
|
Enable pylint and numpy rules (#15218)
### Description Enable pylint and numpy rules ### Motivation and Context Modernize numpy usage and enable more quality checks |
||
|
|
d834ec895a
|
Adopt linrtunner as the linting tool - take 2 (#15085)
### Description `lintrunner` is a linter runner successfully used by pytorch, onnx and onnx-script. It provides a uniform experience running linters locally and in CI. It supports all major dev systems: Windows, Linux and MacOs. The checks are enforced by the `Python format` workflow. This PR adopts `lintrunner` to onnxruntime and fixed ~2000 flake8 errors in Python code. `lintrunner` now runs all required python lints including `ruff`(replacing `flake8`), `black` and `isort`. Future lints like `clang-format` can be added. Most errors are auto-fixed by `ruff` and the fixes should be considered robust. Lints that are more complicated to fix are applied `# noqa` for now and should be fixed in follow up PRs. ### Notable changes 1. This PR **removed some suboptimal patterns**: - `not xxx in` -> `xxx not in` membership checks - bare excepts (`except:` -> `except Exception`) - unused imports The follow up PR will remove: - `import *` - mutable values as default in function definitions (`def func(a=[])`) - more unused imports - unused local variables 2. Use `ruff` to replace `flake8`. `ruff` is much (40x) faster than flake8 and is more robust. We are using it successfully in onnx and onnx-script. It also supports auto-fixing many flake8 errors. 3. Removed the legacy flake8 ci flow and updated docs. 4. The added workflow supports SARIF code scanning reports on github, example snapshot:  5. Removed `onnxruntime-python-checks-ci-pipeline` as redundant ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Unified linting experience in CI and local. Replacing https://github.com/microsoft/onnxruntime/pull/14306 --------- Signed-off-by: Justin Chu <justinchu@microsoft.com> |
||
|
|
1d32285536
|
Statistics tool for ORTModule convergence parity (#15020)
### Statistics tool for ORTModule convergence parity
As ORTModule get more and more validated, it is pretty fast to
intergrade PyTorch based model with ORT.
The same time, we need make sure once there is convergence issue, we
don't spend months of time to investigate. As part of this efforts, this
PR is introducing a tool to dump activation statistics without much
involvement from users. The dumping results contains only some statistic
numbers plus sampled data, which is not big, compared with dumping all
the tensors, it is much faster and space efficient.
For us to use it, two single lines are needed before wrapping ORTModule.
For baseline run, need also apply the same trick.
```
+ from onnxruntime.training.utils.hooks import SubscriberManager, StatisticsSubscriber
+ SubscriberManager.subscribe(model, [StatisticsSubscriber("pt_out", override_output_dir=True)])
```
Once you run the steps, following command can be used to merge result
into per-step-summary respectively for ORT and baseline runs.
```bash
python -m onnxruntime.training.utils.hooks.merge_activation_summary --pt_dir pt_out --ort_dir ort_out --output_dir /tmp/output
```
Docs is added here as part of this PR [convergence investigation
notes](https://github.com/microsoft/onnxruntime/blob/pengwa/conv_tool/docs/ORTModule_Convergence_Notes.md)
Based on the generated merged files, we can compare them with tools.

### Design and Implementation
This PR introduced a common mechanism registering custom logic for
nn.Module's post forward hooks. And statistics for activation
(StatisticsSubscriber) is one of the implementations. If there is other
needs, we can define another XXSubscriber to do the customized things.
|
||
|
|
fdce4fa6af
|
Format all python files under onnxruntime with black and isort (#11324)
Description: Format all python files under onnxruntime with black and isort. After checking in, we can use .git-blame-ignore-revs to ignore the formatting PR in git blame. #11315, #11316 |
||
|
|
7691e7ed12
|
Introduce load balancing dataset samplers (#10163) |