onnxruntime/orttraining/orttraining/python
pengwa 516c8e95fa
Optimize SCE loss compute (#15401)
### Optimize SCE loss compute

Compute optimization based on label data sparsity:
- Inserted a ShrunkenGather node before the SCELoss node to filter out invalid
labels before the loss is computed.
- Added support for propagating ShrunkenGather upstream.
- Added tests for the above.
- Added an environment-variable flag to enable the label sparsity optimization.
It is disabled by default for now and will be enabled after comprehensive
benchmarking.
- Extracted common logic from core/optimizer/compute_optimizer_test.cc into
test_optimizer_utils.h/cc so that the common functions can be shared by both
core/optimizer/compute_optimizer_test.cc and
orttraining/core/optimizer/compute_optimizer_test.cc.
- Extracted common logic into shared_utils.h/cc: `GetONNXOpSetVersion` and
`Create1DInitializerFromVector`.
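The label sparsity idea behind the first two bullets can be illustrated with a minimal pure-Python sketch. It assumes invalid labels are marked PyTorch-style with an `ignore_index` of -100, and it uses a hypothetical `shrunken_gather` helper as a stand-in for the ShrunkenGather op (a plain row gather along axis 0). None of this is ONNX Runtime's actual kernel or graph-transformer code. The sketch shows that dropping invalid rows before the loss yields the same mean loss as masking them inside it, and that an elementwise op commutes with the row gather, which is what makes moving the gather upstream safe.

```python
import math

# Assumption: invalid labels are marked PyTorch-style with ignore_index = -100.
IGNORE_INDEX = -100


def shrunken_gather(rows, indices):
    """Hypothetical stand-in for ShrunkenGather: gather entries along axis 0."""
    return [rows[i] for i in indices]


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def sce_loss_masked(logits_rows, labels, ignore_index=IGNORE_INDEX):
    """Baseline: run log-softmax on every row, then average only valid rows."""
    total, count = 0.0, 0
    for row, label in zip(logits_rows, labels):
        log_probs = log_softmax(row)  # also computed for ignored rows (wasted work)
        if label != ignore_index:
            total -= log_probs[label]
            count += 1
    return total / count if count else 0.0


def sce_loss_dense(logits_rows, labels):
    """Mean loss over rows that are all assumed to carry valid labels."""
    losses = [-log_softmax(row)[label] for row, label in zip(logits_rows, labels)]
    return sum(losses) / len(losses) if losses else 0.0


def scale(rows, factor=0.5):
    """An elementwise op standing in for work done upstream of the loss."""
    return [[factor * x for x in row] for row in rows]


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.0, -2.0, 0.0], [0.0, 0.0, 3.0]]
labels = [0, IGNORE_INDEX, 2, IGNORE_INDEX]

# Indices of rows whose label is valid; only these need to reach the loss.
valid = [i for i, label in enumerate(labels) if label != IGNORE_INDEX]

# Baseline loss vs. loss computed on the gathered (shrunken) batch.
full_loss = sce_loss_masked(logits, labels)
shrunk_loss = sce_loss_dense(shrunken_gather(logits, valid), shrunken_gather(labels, valid))

# Elementwise ops commute with a row gather, so the gather can move upstream
# and the op then runs on fewer rows without changing the result.
gather_after = shrunken_gather(scale(logits), valid)
gather_before = scale(shrunken_gather(logits, valid))
```

The saving in this sketch comes from the smaller batch: when many labels are invalid (for example, padding), the log-softmax and any upstream elementwise work run only on the valid rows.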


2023-04-13 13:02:12 +08:00
| Name | Last commit | Date |
|---|---|---|
| deprecated | Adopt lintrunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| training | Optimize SCE loss compute (#15401) | 2023-04-13 13:02:12 +08:00 |
| checkpointing_utils.py | Adopt lintrunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| ort_trainer.py | Adopt lintrunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| orttraining_pybind_common.h | Re-work global objects dependencies in pybind layer. (#14941) | 2023-03-10 13:55:31 -08:00 |
| orttraining_pybind_state.cc | Optimize SCE loss compute (#15401) | 2023-04-13 13:02:12 +08:00 |
| orttraining_python_module.cc | Delete eager mode code and increase minimal required python version to 3.8 (#15450) | 2023-04-10 16:00:04 -07:00 |
| orttraining_python_module_eager.h | Abjindal/clean eager backend (#10055) | 2022-01-19 14:20:09 -08:00 |
| pt_patch.py | Adopt lintrunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |