Adds ATen fallback for scaled_dot_product_attention (#21107)

### Description
<!-- Describe your changes. -->

Introduces an ATen fallback for
`torch.nn.functional.scaled_dot_product_attention`. This operator was
introduced in torch 2.0 and, since then, has had many updates including
the implementation of memory efficient attention for V100 machines. The
current torchscript exporter exports a subgraph for attention which does
not provide the same memory savings that PyTorch's memory efficient
attention kernel provides. Allowing fallback to PyTorch ATen op for
attention helps mitigate memory spike issues for models leveraging
memory efficient attention.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Memory issues arose when integrating ONNX Runtime Training with AML
Stable Diffusion.

---------

Co-authored-by: root <prathikrao@microsoft.com>
This commit is contained in:
Prathik Rao 2024-07-22 16:37:04 -07:00 committed by GitHub
parent 5b9369e93c
commit 11ad299451
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 156 additions and 1 deletions

View file

@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
#### ORTMODULE_ATEN_SDPA_FALLBACK
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if user model leverages memory efficient attention WITHOUT masking (ie. attn_mask=None). Utilize GPU profiling looks like NVIDIA Nsight Systems to identify if user model leverages memory efficient attention.
```bash
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*

View file

@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
}
std::vector<ArgDef> output_args;
for (const auto& output : node_def.outputs) {
for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
// If the input is not used in the forward computation, we don't need it for gradient computation
// Required for ORTMODULE_ATEN_SDPA_FALLBACK
if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
continue;
}
if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
output_args.emplace_back(ArgDef());
continue;
}
const auto& output = node_def.outputs[output_index];
if (output.find("GI(") == 0) {
size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
output_args.emplace_back(GI(index));

View file

@ -25,6 +25,7 @@
# 'is_tensor' is optional, if not present, the default is False.
import json
import os
from onnxruntime.capi import _pybind_state as C
@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)
ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
def scaled_dot_product_attention_gradient():
return [
(
"Constant",
[],
["grad_input_mask"],
{"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
),
(
("ATen", "org.pytorch.aten"),
[
"GO(0)",
"I(0)",
"I(1)",
"I(2)",
"I(3)",
"O(0)",
"O(1)",
"O(2)",
"O(3)",
"I(5)",
"grad_input_mask",
"I(6)",
"I(7)",
],
["GI(0)", "GI(1)", "GI(2)", ""],
{"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
),
]

View file

@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
from typing import Callable
import torch
@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", casted_input, axis_i=dim)
return softmax
ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
@register_symbolic("scaled_dot_product_attention")
def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
return g.op(
"org.pytorch.aten::ATen",
query,
key,
value,
attn_mask,
compute_logsumexp,
dropout_p_f,
is_causal,
scale,
operator_s="_scaled_dot_product_efficient_attention",
outputs=4,
)[0]

View file

@ -6953,3 +6953,74 @@ def test_layerwise_recompute_pythonop_determinstic():
else:
if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
@pytest.mark.skipif(
Version(torch.__version__) < Version("2.3.0"),
reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
)
def test_aten_attention():
from torch.nn.attention import SDPBackend, sdpa_kernel
class _NeuralNetAttention(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, q, k, v, attn_mask=None):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
def gen_inputs(device, dtype):
return [
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
]
def run_step(model, inputs, attn_mask=None):
prediction = model(*inputs, attn_mask)
prediction.sum().backward()
return prediction
device = "cuda"
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK
pt_model = _NeuralNetAttention().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
# reset manual seed to reset the generator
torch.manual_seed(2333)
pt_input = gen_inputs(device=device, dtype=torch.float32)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
_test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
_test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)
execution_mgr = ort_model._torch_module._execution_manager._training_manager
from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
path = os.path.join(
execution_mgr._debug_options.save_onnx_models.path,
_get_onnx_file_name(
execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
),
)
onnx_model = onnx.load(path)
onnx_nodes = onnx_model.graph.node
mem_eff_attn_nodes = 0
for node in onnx_nodes:
if "ATen" in node.name:
for attr in node.attribute:
if b"_scaled_dot_product_efficient_attention" in attr.s:
mem_eff_attn_nodes += 1
assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]