onnxruntime/orttraining/orttraining/python/training
Prathik Rao 11ad299451
Adds ATen fallback for scaled_dot_product_attention (#21107)
### Description

Introduces an ATen fallback for
`torch.nn.functional.scaled_dot_product_attention`. This operator was
introduced in torch 2.0 and has since received many updates, including
a memory-efficient attention implementation for V100 machines. The
current TorchScript exporter exports a subgraph for attention that does
not provide the same memory savings as PyTorch's memory-efficient
attention kernel. Allowing a fallback to the PyTorch ATen op for
attention helps mitigate memory spikes in models that rely on
memory-efficient attention.
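
To illustrate the memory difference described above, here is a minimal pure-Python sketch. It is not the ONNX Runtime or PyTorch implementation, and the function names `naive_attention` and `streaming_attention` are hypothetical. The first function materializes the full query-by-key score matrix, roughly what an exported attention subgraph does. The second processes keys in blocks with a running softmax, which is the general idea behind memory-efficient attention kernels. Both compute `softmax(Q K^T / sqrt(d)) V`.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, k, v):
    """Attention that holds the full len(q) x len(k) score matrix in memory."""
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[scale * dot(qi, kj) for kj in k] for qi in q]  # full matrix
    out = []
    for row in scores:
        m = max(row)  # subtract the row max for numerical stability
        w = [math.exp(s - m) for s in row]
        total = sum(w)
        out.append([sum(wi * vj[c] for wi, vj in zip(w, v)) / total
                    for c in range(len(v[0]))])
    return out


def streaming_attention(q, k, v, block=2):
    """Same result, but only `block` scores per query are alive at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        m = float("-inf")  # running max of the scores seen so far
        denom = 0.0        # running softmax denominator
        acc = [0.0] * dv   # running weighted sum of value rows
        for start in range(0, len(k), block):
            kb = k[start:start + block]
            vb = v[start:start + block]
            s = [scale * dot(qi, kj) for kj in kb]
            m_new = max(m, max(s))
            corr = math.exp(m - m_new)  # rescale earlier partial sums
            w = [math.exp(x - m_new) for x in s]
            denom = denom * corr + sum(w)
            acc = [a * corr + sum(wi * vj[c] for wi, vj in zip(w, vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])
    return out


# Small illustrative inputs (made up for this sketch).
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [2.0, 0.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
reference = naive_attention(q, k, v)
blocked = streaming_attention(q, k, v, block=2)
```

The two outputs agree up to floating-point rounding. The streaming variant never stores more than `block` scores per query, which is why falling back to a kernel of this kind can avoid the memory spike of a fully materialized attention subgraph.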

### Motivation and Context

Memory issues arose when integrating ONNX Runtime Training with AML
Stable Diffusion.

---------

Co-authored-by: root <prathikrao@microsoft.com>
2024-07-22 16:37:04 -07:00
..
| Name | Last commit | Date |
| --- | --- | --- |
| amp | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| api | Introduce a Nominal Checkpoint for On-Device Training (#19232) | 2024-01-30 22:11:25 -08:00 |
| experimental | Manage ORTModule configurations consistently (#16396) | 2023-06-27 19:19:36 +08:00 |
| onnxblock | Fix typos according to reviewdog report. (#21335) | 2024-07-22 13:37:32 -07:00 |
| optim | Fix typos - 1st Wave (#21278) | 2024-07-11 13:35:08 +08:00 |
| ort_triton | Support BFloat16 for Triton Codegen (#20353) | 2024-04-18 17:15:11 +08:00 |
| ortmodule | Adds ATen fallback for scaled_dot_product_attention (#21107) | 2024-07-22 16:37:04 -07:00 |
| utils | ORTModule GraphTransitionManager (#19007) | 2024-07-03 10:53:31 +08:00 |
| __init__.py | Bump ruff to 0.3.2 and black to 24 (#19878) | 2024-03-13 10:00:32 -07:00 |
| _utils.py | Removed all the deprecated python training code and related tests and utils (#18333) | 2023-11-17 18:19:21 -08:00 |
| artifacts.py | Enable >2GB models + allow model paths to be passed for generate_artifacts API (#20958) | 2024-06-21 09:55:26 -07:00 |