onnxruntime/orttraining/orttraining/python/training/ortmodule
guyang3532 341484e67c
Embedding sparsity optimization (#16141)
### Description
Optimize the compute graph by eliminating padding in embeddings.


### Motivation and Context
Computing on padding in the nodes after the embedding is unnecessary and
wastes computational resources.
This PR adds a PaddingElimination optimizer that checks for and
eliminates the padding after the embedding automatically by modifying the
graph.

### Implementation:
1. Find and check the embedding node in the graph.
2. Traverse the subgraph after the embedding node and record all the
input nodes and output nodes of this subgraph.
3. Insert 'Reshape + ShrunkenGather' to flatten the shape of each input node
from [batch_size, seqlen, ...] to [valid_token_without_padding, ...],
and insert 'GatherGrad + Reshape' to unflatten the shape of each output node
from [valid_token_without_padding, ...] to [batch_size, seqlen, ...].
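The shape transformation in step 3 can be illustrated with a small NumPy sketch. This is not the ONNX Runtime implementation: the function names, the `padding_idx` parameter, and the use of plain NumPy indexing in place of the 'ShrunkenGather' and 'GatherGrad' graph nodes are all assumptions made for illustration.

```python
import numpy as np


def flatten_valid_tokens(hidden, input_ids, padding_idx):
    """Hypothetical helper: [batch_size, seqlen, ...] -> [valid_tokens, ...].

    Mirrors the 'Reshape + gather' idea: merge the first two dims,
    then keep only the rows whose token is not the padding index.
    """
    batch_size, seqlen = input_ids.shape
    flat = hidden.reshape(batch_size * seqlen, *hidden.shape[2:])
    valid_idx = np.flatnonzero(input_ids.reshape(-1) != padding_idx)
    return flat[valid_idx], valid_idx


def unflatten_valid_tokens(compact, valid_idx, batch_size, seqlen):
    """Hypothetical helper: [valid_tokens, ...] -> [batch_size, seqlen, ...].

    Mirrors the 'scatter + Reshape' idea: write the compact rows back
    into a zero-filled buffer, then restore the original leading dims.
    """
    full = np.zeros((batch_size * seqlen, *compact.shape[1:]), dtype=compact.dtype)
    full[valid_idx] = compact
    return full.reshape(batch_size, seqlen, *compact.shape[1:])


# Example: two sequences of length 3, with 0 assumed to be the padding index.
input_ids = np.array([[5, 6, 0], [7, 0, 0]])
hidden = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)

compact, valid_idx = flatten_valid_tokens(hidden, input_ids, padding_idx=0)
restored = unflatten_valid_tokens(compact, valid_idx, batch_size=2, seqlen=3)
```

In this example only three of the six token positions are valid, so any per-token computation placed between the two helpers would run on half as many rows. Padding positions in `restored` are zero-filled.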

---------

Co-authored-by: mindest <linminuser@gmail.com>
2023-06-19 20:34:53 +08:00
..
experimental Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
torch_cpp_extensions Run clang-format in CI (#15524) 2023-04-18 09:26:58 -07:00
__init__.py Bump ruff in CI (#15533) 2023-04-17 10:11:44 -07:00
_custom_autograd_function.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_custom_autograd_function_exporter.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_custom_autograd_function_runner.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_custom_gradient_registry.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_custom_op_symbolic_registry.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_execution_agent.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_fallback.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_fallback_exceptions.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_gradient_accumulation_manager.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_graph_execution_interface.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_graph_execution_manager.py Embedding sparsity optimization (#16141) 2023-06-19 20:34:53 +08:00
_graph_execution_manager_factory.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_inference_manager.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_io.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_logger.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_onnx_models.py Type hint for ORTModule (#15938) 2023-05-25 09:28:20 +08:00
_runtime_inspector.py Introduce memory observer for ORTModule (#16213) 2023-06-15 15:45:36 +08:00
_torch_module_factory.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_torch_module_interface.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_torch_module_ort.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
_torch_module_pytorch.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00
_training_manager.py Introduce memory observer for ORTModule (#16213) 2023-06-15 15:45:36 +08:00
_utils.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00
debug_options.py Enable conditional optimization automatically (#15885) 2023-05-23 13:08:05 +08:00
ortmodule.py Consolidate ORTModule logging (#16078) 2023-06-01 10:09:12 +08:00