mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
Disable fused causal attention (#14732)
There is accuracy regression in GPT-2 model. Top1 match rate (vs PyTorch model) drops about 1%. The cause is the fused causal attention uses fp16 accumulation. Disable it by default and add an environment variable ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1 to turn on it manually. It also updated the GPT-2 parity test script to generate left side padding to reflect the actual usage. To test: ``` python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2 --output gpt2.onnx -o -p fp16 --use_gpu ``` The top1-match-rate in the output is on-par with ORT 1.13.1.
This commit is contained in:
parent
25e10f413e
commit
c0d2472ede
10 changed files with 55 additions and 40 deletions
|
|
@ -25,7 +25,7 @@ enum AttentionQkvFormat {
|
|||
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
|
||||
};
|
||||
|
||||
enum AttentionKernelType{
|
||||
enum AttentionKernelType {
|
||||
AttentionKernel_Unfused,
|
||||
AttentionKernel_TrtFusedAttention,
|
||||
AttentionKernel_TrtFlashAttention,
|
||||
|
|
@ -38,15 +38,15 @@ enum AttentionKernelType{
|
|||
struct AttentionParameters {
|
||||
int batch_size;
|
||||
int sequence_length;
|
||||
int kv_sequence_length; // input sequence length of K or V
|
||||
int past_sequence_length; // sequence length in past state of K or V
|
||||
int total_sequence_length; // total sequence length of K or V
|
||||
int max_sequence_length; // max sequence length from 4D mask
|
||||
int input_hidden_size; // first dimension of weights for input projection
|
||||
int hidden_size; // hidden size of Q or K
|
||||
int head_size; // hidden size per head of Q or K
|
||||
int v_hidden_size; // hidden size of V
|
||||
int v_head_size; // hidden size per head of V
|
||||
int kv_sequence_length; // input sequence length of K or V
|
||||
int past_sequence_length; // sequence length in past state of K or V
|
||||
int total_sequence_length; // total sequence length of K or V
|
||||
int max_sequence_length; // max sequence length from 4D mask
|
||||
int input_hidden_size; // first dimension of weights for input projection
|
||||
int hidden_size; // hidden size of Q or K
|
||||
int head_size; // hidden size per head of Q or K
|
||||
int v_hidden_size; // hidden size of V
|
||||
int v_head_size; // hidden size per head of V
|
||||
int num_heads;
|
||||
bool is_unidirectional;
|
||||
bool past_present_share_buffer;
|
||||
|
|
@ -56,13 +56,17 @@ struct AttentionParameters {
|
|||
};
|
||||
|
||||
namespace attention {
|
||||
// Environment variable to enable or disable fused self/causal attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
|
||||
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
|
||||
// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled).
|
||||
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
|
||||
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
|
||||
|
|
|
|||
|
|
@ -39,12 +39,15 @@ REGISTER_KERNEL_TYPED(MLFloat16)
|
|||
|
||||
template <typename T>
|
||||
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
|
||||
disable_fused_runner_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
|
||||
disable_fused_self_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
|
||||
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
|
||||
enable_fused_causal_attention_ = sizeof(T) == 2 &&
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
|
|
@ -97,14 +100,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
|
||||
if (is_unidirectional_) { // GPT
|
||||
if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
|
||||
// GPT fused kernels requires left side padding. mask can be:
|
||||
// none (no padding), 1D sequence lengths or 2d mask.
|
||||
// Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
|
||||
// where past state is empty.
|
||||
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
bool use_causal_fused_runner = !disable_fused_runner_ &&
|
||||
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
|
||||
bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
|
||||
nullptr == relative_position_bias &&
|
||||
parameters.past_sequence_length == 0 &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
|
|
@ -121,7 +123,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
fused_runner = fused_fp16_runner_.get();
|
||||
}
|
||||
} else { // BERT
|
||||
bool use_fused_runner = !disable_fused_runner_ &&
|
||||
bool use_fused_runner = !disable_fused_self_attention_ &&
|
||||
(nullptr == mask_index || is_mask_1d_seq_len) &&
|
||||
nullptr == past &&
|
||||
nullptr == present &&
|
||||
|
|
|
|||
|
|
@ -21,8 +21,9 @@ class Attention final : public CudaKernel, public AttentionBase {
|
|||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
protected:
|
||||
bool disable_fused_runner_;
|
||||
bool disable_fused_self_attention_;
|
||||
bool enable_trt_flash_attention_;
|
||||
bool enable_fused_causal_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -620,12 +620,14 @@ Status QkvToContext(
|
|||
if (use_fused_kernel || use_fused_causal) {
|
||||
int* sequence_offset = reinterpret_cast<int*>(scratch1);
|
||||
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
|
||||
DUMP_TENSOR_D("mask", reinterpret_cast<const int*>(data.mask_index), batch_size, sequence_length);
|
||||
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
|
||||
} else {
|
||||
sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
|
||||
data.mask_index, batch_size, sequence_length, stream,
|
||||
sequence_offset);
|
||||
}
|
||||
DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
|
||||
|
|
|
|||
|
|
@ -42,8 +42,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
|
|||
|
||||
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
|
||||
|
||||
disable_fused_runner_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
|
||||
disable_fused_self_attention_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
|
||||
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
|
|
@ -124,7 +124,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
bool use_fused_runner = !disable_fused_runner_ &&
|
||||
bool use_fused_runner = !disable_fused_self_attention_ &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
nullptr == relative_position_bias &&
|
||||
(value != nullptr || key == nullptr) &&
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class MultiHeadAttention final : public CudaKernel {
|
|||
protected:
|
||||
int num_heads_; // number of attention heads
|
||||
float mask_filter_value_;
|
||||
bool disable_fused_runner_;
|
||||
bool disable_fused_self_attention_;
|
||||
bool enable_trt_flash_attention_;
|
||||
bool disable_fused_cross_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
|
|
|
|||
|
|
@ -543,6 +543,7 @@ def get_ort_environment_variables():
|
|||
# Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
|
||||
env_names = [
|
||||
"ORT_DISABLE_FUSED_ATTENTION",
|
||||
"ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
|
||||
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
|
||||
"ORT_DISABLE_TRT_FLASH_ATTENTION",
|
||||
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
|||
|
||||
from benchmark_helper import Precision
|
||||
from float16 import float_to_float16_max_diff
|
||||
from fusion_options import AttentionMaskFormat
|
||||
from io_binding_helper import IOBindingHelper
|
||||
from onnx_model import OnnxModel
|
||||
from torch_onnx_export_helper import torch_onnx_export
|
||||
|
|
@ -188,6 +187,7 @@ class Gpt2Helper:
|
|||
input_ids_dtype: torch.dtype = torch.int32,
|
||||
position_ids_dtype: torch.dtype = torch.int32,
|
||||
attention_mask_dtype: torch.dtype = torch.int32,
|
||||
left_side_padding: bool = True,
|
||||
) -> Gpt2Inputs:
|
||||
"""Create random inputs for GPT2 model.
|
||||
Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
|
||||
|
|
@ -218,9 +218,14 @@ class Gpt2Helper:
|
|||
dtype=attention_mask_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if total_sequence_length >= 2:
|
||||
padding_position = random.randint(0, total_sequence_length - 1) # test input with padding.
|
||||
attention_mask[:, padding_position] = 0
|
||||
for i in range(batch_size):
|
||||
padding_length = random.randint(0, total_sequence_length - 1)
|
||||
if left_side_padding:
|
||||
attention_mask[i, :padding_length] = 0
|
||||
else: # right side padding
|
||||
attention_mask[i, total_sequence_length - padding_length :] = 0
|
||||
|
||||
# Deduce position_ids from attention mask
|
||||
position_ids = None
|
||||
|
|
@ -517,11 +522,6 @@ class Gpt2Helper:
|
|||
|
||||
optimization_options = FusionOptions("gpt2")
|
||||
|
||||
if is_float16 and stage == 1:
|
||||
# For init_decoder, enable mask index to use fused causal cuda kernel.
|
||||
# Potentially, we can add other optimization like unpad for effective transformer
|
||||
optimization_options.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
|
||||
|
||||
# TODO(hasesh): Investigate parity issue for GPT-2 fp16 when SkipLayerNormalization
|
||||
# is enabled
|
||||
if is_float16:
|
||||
|
|
@ -841,6 +841,7 @@ class Gpt2Helper:
|
|||
input_ids_dtype=input_ids_dtype,
|
||||
position_ids_dtype=position_ids_dtype,
|
||||
attention_mask_dtype=attention_mask_dtype,
|
||||
left_side_padding=True,
|
||||
)
|
||||
outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs)
|
||||
if use_io_binding:
|
||||
|
|
@ -868,6 +869,7 @@ class Gpt2Helper:
|
|||
max_abs_diff_list.append(max_abs_diff)
|
||||
if is_all_close:
|
||||
passed_test_cases += 1
|
||||
|
||||
if is_top1_matched:
|
||||
top1_matched_cases += 1
|
||||
top1_matched_cases_per_run[run_id] += 1
|
||||
|
|
|
|||
|
|
@ -930,7 +930,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}};
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
use_past_state, past_sequence_length, &past_data, &present_data);
|
||||
|
|
@ -941,7 +942,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
use_past_state, past_sequence_length, &past_data, &present_data);
|
||||
|
|
@ -952,7 +954,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
|
||||
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
use_past_state, past_sequence_length, &past_data, &present_data);
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
|
|
@ -195,7 +195,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
|
|
@ -209,7 +209,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
|
|
@ -224,7 +224,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
|
|
@ -239,7 +239,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
|
|
|
|||
Loading…
Reference in a new issue