Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-14 20:48:00 +00:00
[CUDA] enable causal in MultiHeadAttention (#21852)
### Description
Enable causal in the MultiHeadAttention CUDA operator. All formats (Q_K_V_BSNH_BSNH_BSNH, Q_K_V_BSNH_BNSH_BNSH, Q_KV_BSNH_BSN2H and QKV_BSN3H) now support causal. Internally, causal is dispatched to the flash attention, efficient attention, or unfused attention kernel.

### Motivation and Context
Currently, MultiHeadAttention has causal enabled in the CPU EP but not in the CUDA EP. That can cause issues in ONNX conversion: a model may run on CPU but fail on CUDA. Enabling causal in CUDA reduces the difference between the CPU and CUDA support matrices.
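As a reference for what the new behavior computes, here is a minimal causal-attention sketch in PyTorch. It is illustrative only: the function name and tensor layout are assumptions, not the operator's schema, and the CUDA kernels reach the same result without materializing the full mask (the dispatch to flash, efficient, or unfused attention happens inside the EP).

```python
import torch

def causal_attention_reference(q, k, v):
    # q, k, v: (batch, num_heads, seq_len, head_size); illustrative layout only.
    scale = 1.0 / (q.shape[-1] ** 0.5)
    scores = torch.matmul(q, k.transpose(-1, -2)) * scale

    # Causal (unidirectional) mask: position i may only attend to positions <= i.
    seq_len = q.shape[-2]
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~allowed, float("-inf"))

    return torch.matmul(torch.softmax(scores, dim=-1), v)
```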
This commit is contained in:
parent d9c57ac7db
commit ad382120fe
4 changed files with 37 additions and 28 deletions
@@ -46,8 +46,6 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
   is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
-  ORT_ENFORCE(!is_unidirectional_,
-              "MHA support CUDA kernel does not Unidirectional. Consider using Attention or GQA instead.");

   kernel_options_ = this->GetAttentionKernelOptions();

@@ -208,13 +206,13 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_cross_attention =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_cross_attention_ &&
+      !is_unidirectional_ &&
       nullptr == key_padding_mask &&
       nullptr == attention_bias &&
       nullptr == past_key && nullptr == present_key &&
       (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
       parameters.hidden_size == parameters.v_hidden_size &&
-      has_fused_cross_attention_kernel(sm, parameters.head_size,
-                                       parameters.kv_sequence_length);
+      has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
   if (use_fused_cross_attention) {
     if (fused_fp16_cross_attention_kernel_ == nullptr) {
       std::call_once(fused_cross_init_once_flag_, [&]() {
@@ -233,6 +231,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   bool use_fused_runner =
       kernel_type == AttentionKernelType::AttentionKernel_Default &&
       !disable_fused_self_attention_ &&
+      !is_unidirectional_ &&
       nullptr == attention_bias &&
       (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
       nullptr == past_key && nullptr == present_key &&
@@ -240,13 +239,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.hidden_size == parameters.v_hidden_size &&
       parameters.sequence_length == parameters.kv_sequence_length &&  // self attention only for fused runner
       FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
-                                        enable_trt_flash_attention_, false);
+                                        enable_trt_flash_attention_, is_unidirectional_);
   if (use_fused_runner) {
     // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
     if (nullptr == fused_fp16_runner_.get()) {
-      constexpr bool is_unidirectional = false;
       std::call_once(fused_fp16_runner_created_, [&]() {
-        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional,
+        fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
                                                           enable_trt_flash_attention_, parameters.scale);
       });
     }
@@ -304,7 +304,7 @@ class CudaSession:
                 tensor.data_ptr(),
             )

-    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False):
+    def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
         """Bind input tensors and run inference"""
         for name, tensor in feed_dict.items():
             assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
@@ -317,7 +317,6 @@ class CudaSession:
             else:
                 self.bind_input_and_buffer_sharing(name, tensor)

         # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU.
         if synchronize:
             self.io_binding.synchronize_inputs()
             self.ort_session.run_with_iobinding(self.io_binding, run_options)
@@ -587,8 +587,8 @@ class OrtMultiHeadAttention:
         self.ort_session = create_session(config, session_options, use_tf32=use_tf32)
         self.feed_dict = config.random_inputs()

-    def infer(self):
-        return self.ort_session.infer(self.feed_dict)
+    def infer(self, run_options=None, synchronize=True):
+        return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize)


 def measure_latency(cuda_session: CudaSession, input_dict):
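With `synchronize` now defaulting to True, callers read back outputs that are guaranteed to be ready; per the comment above, explicit synchronization mainly matters when different streams are used or inputs/outputs live on CPU. Below is a small illustrative timing helper, assuming a `CudaSession`-like object with the `infer(feed_dict, synchronize=...)` signature shown above (the helper itself is not part of this change):

```python
import time

def time_infer(session, feed_dict, repeats=100):
    # Warm-up run so one-time initialization does not skew the measurement.
    session.infer(feed_dict, synchronize=True)

    start = time.perf_counter()
    for _ in range(repeats):
        # synchronize=True ensures each iteration's work is finished before the
        # next starts, so the average reflects end-to-end latency.
        session.infer(feed_dict, synchronize=True)
    return (time.perf_counter() - start) / repeats
```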
@@ -1356,7 +1356,6 @@ if __name__ == "__main__":
         args.repeats = 10000 if args.use_gpu else 100

     if args.use_gpu:
-        assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
         assert torch.cuda.is_available()
         if not args.torch:
             assert "CUDAExecutionProvider" in get_available_providers()
@@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats):
     raise RuntimeError(f"Unknown format: {format}")


+def get_causal_support(format: InputFormats):
+    if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH:
+        return [True, False]
+
+    if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH:
+        return [True, False]
+
+    if format == InputFormats.Q_KV_BSNH_BSN2H:
+        return [True, False]
+
+    if format == InputFormats.QKV_BSN3H:
+        return [True, False]
+
+    raise RuntimeError(f"Unknown format: {format}")
+
+
 def get_atten_bias_support():
     atten_bias_options = [
         # (has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1)
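All four input formats now report both causal and non-causal support, so the helper added above is equivalent to the more compact form below (an illustrative rewrite assuming the same `InputFormats` enum, not the code in this commit). Keeping one branch per format in the committed version leaves room to restrict causal for an individual format later without touching callers.

```python
def get_causal_support(format: InputFormats):
    known_formats = (
        InputFormats.Q_K_V_BSNH_BSNH_BSNH,
        InputFormats.Q_K_V_BSNH_BNSH_BNSH,
        InputFormats.Q_KV_BSNH_BSN2H,
        InputFormats.QKV_BSN3H,
    )
    # Causal and non-causal are both exercised for every known format.
    if format in known_formats:
        return [True, False]
    raise RuntimeError(f"Unknown format: {format}")
```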
@@ -215,7 +231,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
         for num_heads in heads:
             for head_size in head_sizes:
                 for format in formats:
-                    for causal in [True, False]:
+                    for causal in get_causal_support(format):
                         for mask_format in mask_formats:
                             for has_bias in get_bias_support(format):
                                 for (
@@ -256,8 +272,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_bias in get_bias_support(format):
                         config = MultiHeadAttentionConfig(
                             batch_size=batch_size,
@@ -308,7 +324,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
         for num_heads in heads:
             for head_size in head_sizes:
                 for format in formats:
-                    for causal in [True, False]:
+                    for causal in get_causal_support(format):
                         for has_past_input in [True, False]:
                             for mask_format in mask_formats:
                                 for has_bias in get_bias_support(format):
@@ -353,8 +369,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool):
             has_attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1 = atten_bias_options[
                 i % len(atten_bias_options)
             ]
-            for causal in [True, False]:
-                for format in formats:
+            for format in formats:
+                for causal in get_causal_support(format):
                     for has_past_input in [True, False]:
                         for has_bias in get_bias_support(format):
                             sequence_length = 1 if has_past_input else past_sequence_length
@@ -397,7 +413,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, False)

     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []  # list of configurations to run in parallel
@@ -437,7 +453,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool):
     device, dtype, formats = get_provider_support_info(provider, True)

     for format in formats:
-        for causal in [True, False]:
+        for causal in get_causal_support(format):
             for num_heads in heads:
                 for head_size in head_sizes:
                     configs = []
@@ -494,12 +510,8 @@ def parity_check_mha(
     rtol=1e-3,
     atol=1e-3,
 ):
-    # CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return
-
     ort_mha = OrtMultiHeadAttention(config, use_tf32=False)
-    ort_outputs = ort_mha.infer()
+    ort_outputs = ort_mha.infer(synchronize=True)
     out = ort_outputs["output"]
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))

@@ -602,9 +614,6 @@ def parity_check_mha_multi_threading(
 ):
     # Use the first config to create a session, which is shared by all configs to run in parallel.
     config = test_inputs[0]["config"]
-    # For now, MHA CUDA kernel does not support causal so skip such test cases.
-    if config.causal and config.provider == "CUDAExecutionProvider":
-        return None

     # Some kernel does not support certain input format.
     if attention_kernel not in [
@@ -784,6 +793,10 @@ class TestMultiHeadAttention(unittest.TestCase):

     def run_mha_cuda_multi_threading(self, attention_kernel):
         for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode):
+            if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0):
+                # TRT fused causal is disabled by default so skip the test of causal for multi-threading.
+                continue
+
             test_inputs = []
             for config in configs:
                 ort_inputs = config.random_inputs()