onnxruntime/orttraining/orttraining/python/training
guyang3532 471e969e2f
Check padding density by input of embedding module (#19821)
### Description
The PaddingElimination optimization is enabled when the padding density
of the embedding input is less than 90%, so we need to check that
density to decide whether to enable the optimization.
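The 90% rule above can be illustrated with a small sketch. It assumes that "density" means the fraction of tokens in the embedding input that are *not* equal to the embedding's `padding_idx`, so a fully unpadded batch has density 1.0. The function names `padding_density` and `should_eliminate_padding` are hypothetical and are not ONNX Runtime APIs.

```python
from typing import Sequence

# Threshold taken from the description above (90%).
DENSITY_THRESHOLD = 0.9


def padding_density(token_ids: Sequence[Sequence[int]], padding_idx: int) -> float:
    """Return the fraction of tokens in the batch that are not padding (assumed definition)."""
    total = sum(len(row) for row in token_ids)
    if total == 0:
        return 1.0  # empty batch: treat as fully dense (assumption)
    valid = sum(1 for row in token_ids for tok in row if tok != padding_idx)
    return valid / total


def should_eliminate_padding(token_ids: Sequence[Sequence[int]], padding_idx: int) -> bool:
    """Enable the optimization only when density is strictly below the threshold."""
    return padding_density(token_ids, padding_idx) < DENSITY_THRESHOLD


# Two sequences padded with 0: 5 valid tokens out of 10, so density is 0.5.
batch = [[5, 8, 2, 0, 0], [7, 3, 0, 0, 0]]
print(padding_density(batch, padding_idx=0), should_eliminate_padding(batch, padding_idx=0))
```

With a density of exactly 0.9 the strict `<` comparison leaves the optimization disabled, matching the "less than 90%" wording.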

Before this PR, we only checked the graph inputs and correlated one of
them with the embedding node by walking the graph backward from the
embedding node to a graph input.
This approach is hard to generalize, because there may be complicated
patterns between the graph input and the embedding node.

This PR checks the padding density using the direct input of the
embedding module, rather than the graph input, during the first graph
execution when exporting the ONNX graph.
If the density is below 90%, a flag PythonOp is inserted after the
embedding node:
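Checking the embedding's direct input only on the first execution can be sketched as a wrapper that inspects its argument once and then simply forwards calls. This is a hypothetical pure-Python illustration (`FirstCallDensityProbe` is not an ONNX Runtime class, and density is again assumed to be the fraction of non-padding tokens); in PyTorch a comparable effect could be obtained with `torch.nn.Module.register_forward_pre_hook` on the embedding module, though that is not necessarily how ORTModule implements it.

```python
from typing import Callable, List, Optional, Sequence


class FirstCallDensityProbe:
    """Hypothetical wrapper that measures padding density of an embedding's input once."""

    def __init__(self, embed_fn: Callable, padding_idx: int, threshold: float = 0.9):
        self.embed_fn = embed_fn
        self.padding_idx = padding_idx
        self.threshold = threshold
        self.density: Optional[float] = None  # recorded on the first call only

    def __call__(self, token_ids: Sequence[Sequence[int]]) -> List[List[float]]:
        if self.density is None:
            # Inspect the direct input of the embedding, not the graph input.
            total = sum(len(row) for row in token_ids)
            valid = sum(1 for row in token_ids for t in row if t != self.padding_idx)
            self.density = valid / total if total else 1.0
        return self.embed_fn(token_ids)

    @property
    def flag_needed(self) -> bool:
        """True when the recorded density calls for the flag PythonOp."""
        return self.density is not None and self.density < self.threshold


# Toy "embedding" that maps each token id to a float.
probe = FirstCallDensityProbe(lambda ids: [[float(t) for t in row] for row in ids], padding_idx=0)
probe([[4, 0, 0, 0]])  # first call: density 0.25 is recorded
probe([[1, 2, 3, 4]])  # later calls do not update the density
print(probe.density, probe.flag_needed)
```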
```
            Embedding
                |
            PythonOp (func_name:_FlagPaddingElimination)   (inserted only if density < 90%)
                |
            Following graph
```
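The rewrite in the diagram above can be sketched on a toy graph representation. Everything here is an assumption for illustration: `Node` is a hypothetical stand-in for an ONNX node, the embedding op type is assumed to be the string `"Embedding"`, nodes are assumed to be in topological order, and graph outputs are not rewired. Only the flag function name `_FlagPaddingElimination` comes from the description.

```python
from dataclasses import dataclass, field
from typing import Dict, List

FLAG_FUNC = "_FlagPaddingElimination"


@dataclass
class Node:
    """Hypothetical minimal graph node."""
    op_type: str
    inputs: List[str]
    outputs: List[str]
    attrs: Dict[str, str] = field(default_factory=dict)


def insert_flag_after_embedding(nodes: List[Node], density: float, threshold: float = 0.9) -> List[Node]:
    """Return a new node list with a flag PythonOp after each Embedding if density < threshold."""
    if density >= threshold:
        return list(nodes)
    result: List[Node] = []
    rename: Dict[str, str] = {}
    for node in nodes:
        # Consumers of an Embedding output now read the flag's output instead.
        copy = Node(node.op_type, [rename.get(i, i) for i in node.inputs],
                    list(node.outputs), dict(node.attrs))
        result.append(copy)
        if copy.op_type == "Embedding":
            emb_out = copy.outputs[0]
            flag_out = emb_out + "_flagged"
            result.append(Node("PythonOp", [emb_out], [flag_out], {"func_name": FLAG_FUNC}))
            rename[emb_out] = flag_out
    return result


graph = [Node("Embedding", ["ids", "weight"], ["emb"]), Node("Add", ["emb", "pos"], ["h"])]
flagged = insert_flag_after_embedding(graph, density=0.5)
print([n.op_type for n in flagged], flagged[2].inputs)
```

Inserting a no-op marker rather than rewriting the graph immediately keeps the density decision (made at export time) separate from the optimization pass that acts on it later.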

When PaddingElimination is invoked, it checks whether the flag
PythonOp (func_name:_FlagPaddingElimination) follows the Embedding node;
if it does, it removes the flag and applies the padding elimination
optimization.
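The consuming side can be sketched on the same kind of toy graph: find a flag PythonOp whose input is produced by an Embedding node, drop it, reconnect its consumers to the Embedding output, and record that output as eligible. As before, `Node`, the `"Embedding"` op-type string, and `consume_padding_flags` are hypothetical; the actual padding removal that follows in PaddingElimination is not shown.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

FLAG_FUNC = "_FlagPaddingElimination"


@dataclass
class Node:
    """Hypothetical minimal graph node."""
    op_type: str
    inputs: List[str]
    outputs: List[str]
    attrs: Dict[str, str] = field(default_factory=dict)


def consume_padding_flags(nodes: List[Node]) -> Tuple[List[Node], List[str]]:
    """Remove flag PythonOps that directly follow an Embedding.

    Returns the cleaned node list and the Embedding outputs eligible for padding elimination.
    A flag PythonOp that does not follow an Embedding is left in place.
    """
    producers = {out: n for n in nodes for out in n.outputs}
    rename: Dict[str, str] = {}
    eligible: List[str] = []
    kept: List[Node] = []
    for node in nodes:
        if node.op_type == "PythonOp" and node.attrs.get("func_name") == FLAG_FUNC:
            src = producers.get(node.inputs[0])
            if src is not None and src.op_type == "Embedding":
                # Bypass the flag: its consumers read the Embedding output directly.
                rename[node.outputs[0]] = node.inputs[0]
                eligible.append(node.inputs[0])
                continue
        kept.append(Node(node.op_type, [rename.get(i, i) for i in node.inputs],
                         list(node.outputs), dict(node.attrs)))
    return kept, eligible


graph = [
    Node("Embedding", ["ids", "weight"], ["emb"]),
    Node("PythonOp", ["emb"], ["emb_flagged"], {"func_name": FLAG_FUNC}),
    Node("Add", ["emb_flagged", "pos"], ["h"]),
]
cleaned, targets = consume_padding_flags(graph)
print([n.op_type for n in cleaned], cleaned[1].inputs, targets)
```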
2024-04-10 18:45:51 +08:00
amp
api Introduce a Nominal Checkpoint for On-Device Training (#19232) 2024-01-30 22:11:25 -08:00
experimental
onnxblock Introduce a Nominal Checkpoint for On-Device Training (#19232) 2024-01-30 22:11:25 -08:00
optim Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
ort_triton Add Symbolic Shape Hint to Triton Codegen Config (#20056) 2024-03-25 15:05:02 +08:00
ortmodule Check padding density by input of embedding module (#19821) 2024-04-10 18:45:51 +08:00
utils Fix and enable few ORTModule Unit Tests (#19847) 2024-03-12 10:49:19 +08:00
__init__.py Bump ruff to 0.3.2 and black to 24 (#19878) 2024-03-13 10:00:32 -07:00
_utils.py
artifacts.py Add support for SGD optimizer in minimal build (#19901) 2024-03-14 11:31:20 -07:00