onnxruntime/orttraining/orttraining/python/training/ortmodule
guyang3532 cfe830b248
Generalize label input sparsity check and refactor (#20636)
### Description
The InsertGatherBeforeSceLoss optimization is enabled when the density
of the label padding is less than 90%. We need to check the label
padding density to decide whether to enable the optimization.

Before this PR, we only checked the graph inputs and correlated one of
them with the SCE node by iterating the graph backward from the SCE node
to a graph input. This is hard to generalize because there may be
complicated patterns between the graph input and the SCE node.

This PR checks the padding density on the direct input of the SCE module
rather than on the graph input. The check runs during the first graph
execution, when the ONNX graph is exported.
If the density is < 90%, a flag PythonOp is inserted after the SCE node:
```
      SoftmaxCrossEntropy
               |
PythonOp (func_name: FlagAndPrintDensity)   (insert if density < 90%)
               |
        Following graph
```
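
As a rough illustration of the density check described above, here is a minimal sketch in plain Python. It is not the code from this PR. It assumes that "density" means the percentage of label entries that are not padding, and that padding is marked with `-100` (PyTorch's default `ignore_index`); the names `label_density` and `should_insert_flag` are hypothetical.

```python
IGNORE_INDEX = -100       # assumed padding marker (PyTorch's default ignore_index)
DENSITY_THRESHOLD = 90.0  # percent, taken from the description above


def label_density(labels, ignore_index=IGNORE_INDEX):
    """Return the percentage of label entries that are not padding."""
    flat = [value for row in labels for value in row]
    if not flat:
        # No labels at all: treat as fully dense so nothing gets flagged.
        return 100.0
    valid = sum(1 for value in flat if value != ignore_index)
    return 100.0 * valid / len(flat)


def should_insert_flag(labels, threshold=DENSITY_THRESHOLD):
    """Hypothetical helper: flag the SCE node when density is below the threshold."""
    return label_density(labels) < threshold


# Two sequences padded to length 4: 3 valid labels out of 8 entries.
labels = [[5, 7, -100, -100], [3, -100, -100, -100]]
```

With this batch, `label_density(labels)` evaluates to 37.5, so `should_insert_flag(labels)` is true and the flag PythonOp would be inserted under these assumptions.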

When InsertGatherBeforeSceLoss is invoked, it checks whether the flag
PythonOp (func_name: FlagAndPrintDensity) follows the SCE node. If it
does, the pass removes the flag and performs the padding elimination
optimization.

If the environment variable ORTMODULE_PRINT_INPUT_DENSITY is set to 1,
the PythonOp (func_name: FlagAndPrintDensity) prints the input density
at each step. In this case the PythonOp is not removed.
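
The flag handling can be pictured with a toy model. The sketch below is an assumption-laden illustration, not the actual graph transformer: the graph is reduced to a linear list of `(op_type, func_name)` tuples, and `print_density_enabled` and `strip_density_flag` are hypothetical names. It only models detection and removal of the flag node; whether the optimization itself runs when the flag is kept is not stated in the description.

```python
import os

FLAG_FUNC_NAME = "FlagAndPrintDensity"


def print_density_enabled(env=None):
    """True when ORTMODULE_PRINT_INPUT_DENSITY is set to "1"."""
    env = os.environ if env is None else env
    return env.get("ORTMODULE_PRINT_INPUT_DENSITY", "0") == "1"


def strip_density_flag(nodes, keep_flag):
    """Scan a linear chain of (op_type, func_name) nodes.

    Returns (new_nodes, flag_found). The flag PythonOp directly after an
    SCE node is dropped unless keep_flag is True.
    """
    result = []
    flag_found = False
    i = 0
    while i < len(nodes):
        result.append(nodes[i])
        following = nodes[i + 1] if i + 1 < len(nodes) else None
        if nodes[i][0] == "SoftmaxCrossEntropy" and following == ("PythonOp", FLAG_FUNC_NAME):
            flag_found = True
            if not keep_flag:
                i += 1  # skip over the flag node so it is not copied
        i += 1
    return result, flag_found


chain = [
    ("SoftmaxCrossEntropy", None),
    ("PythonOp", FLAG_FUNC_NAME),
    ("Add", None),  # stands in for the following graph
]
stripped, found = strip_density_flag(chain, keep_flag=print_density_enabled({}))
```

Here `stripped` contains only the SCE and `Add` nodes and `found` is true. Passing `keep_flag=True`, as would happen with the environment variable set to 1, leaves the chain unchanged.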
2024-05-10 21:55:43 +08:00
experimental Add LayerSpec Support to ORTPipelineModule (#20410) 2024-04-23 17:57:08 -07:00
graph_optimizers ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773) 2024-03-06 10:11:46 +08:00
torch_cpp_extensions Improve perf for mem efficient grad mgmt (#20480) 2024-05-10 08:09:17 +08:00
__init__.py Prompt layer-wise recompute when applicable (#20126) 2024-04-10 11:50:28 +08:00
_custom_autograd_function.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
_custom_autograd_function_exporter.py Prompt layer-wise recompute when applicable (#20126) 2024-04-10 11:50:28 +08:00
_custom_gradient_registry.py Bump ruff linter to 0.2.1 (#19471) 2024-02-08 16:08:27 -08:00
_custom_op_symbolic_registry.py add bf16 support for few ops (#20385) 2024-04-25 11:28:34 -07:00
_execution_agent.py Memory optimization refactor and refinement (#17481) 2023-11-23 11:39:00 +08:00
_fallback.py Bump linter versions (#18341) 2023-11-08 13:04:40 -08:00
_fallback_exceptions.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_gradient_accumulation_manager.py
_graph_execution_interface.py
_graph_execution_manager.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
_graph_execution_manager_factory.py
_inference_manager.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
_io.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
_logger.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_mem_efficient_grad_mgmt.py Improve perf for mem efficient grad mgmt (#20480) 2024-05-10 08:09:17 +08:00
_onnx_models.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_pythonop_helper.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_runtime_inspector.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
_torch_module_factory.py
_torch_module_interface.py
_torch_module_ort.py
_torch_module_pytorch.py
_training_manager.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
_utils.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_zero_stage3_compatibility.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
graph_optimizer_registry.py [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) 2023-10-27 10:29:27 +08:00
options.py Generalize label input sparsity check and refactor (#20636) 2024-05-10 21:55:43 +08:00
ortmodule.py