Disable fused causal attention (#14732)

There is an accuracy regression in the GPT-2 model: the top-1 match rate (vs. the
PyTorch model) drops by about 1%. The cause is that the fused causal attention
kernel uses fp16 accumulation. Disable it by default and add an environment
variable, ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1, to turn it on manually.

This change also updates the GPT-2 parity test script to generate left-side
padding, which reflects actual usage.

To test:
```
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2 --output gpt2.onnx -o -p fp16 --use_gpu
```
The top-1 match rate in the output is on par with ORT 1.13.1.
This commit is contained in:
Tianlei Wu 2023-02-21 09:53:31 -08:00 committed by GitHub
parent 25e10f413e
commit c0d2472ede
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 55 additions and 40 deletions

View file

@ -25,7 +25,7 @@ enum AttentionQkvFormat {
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
};
enum AttentionKernelType{
enum AttentionKernelType {
AttentionKernel_Unfused,
AttentionKernel_TrtFusedAttention,
AttentionKernel_TrtFlashAttention,
@ -38,15 +38,15 @@ enum AttentionKernelType{
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;
@ -56,13 +56,17 @@ struct AttentionParameters {
};
namespace attention {
// Environment variable to enable or disable fused self/causal attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";
// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
// Environment variable to enable or disable TRT fused causal attention kernels. Default is 0 (disabled).
// Note that those causal attention kernels use fp16 accumulation. There is potential accuracy drop using those kernels.
constexpr const char* kEnableFusedCausalAttention = "ORT_ENABLE_FUSED_CAUSAL_ATTENTION";
// Environment variable to enable or disable TRT flash attention. This applies to both self and causal attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).

View file

@ -39,12 +39,15 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
enable_fused_causal_attention_ = sizeof(T) == 2 &&
ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
#if USE_FLASH_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
@ -97,14 +100,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
int sm = device_prop.major * 10 + device_prop.minor;
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
if (is_unidirectional_) { // GPT
if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
// GPT fused kernels requires left side padding. mask can be:
// none (no padding), 1D sequence lengths or 2d mask.
// Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
// where past state is empty.
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
bool use_causal_fused_runner = !disable_fused_runner_ &&
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
nullptr == relative_position_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
@ -121,7 +123,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
fused_runner = fused_fp16_runner_.get();
}
} else { // BERT
bool use_fused_runner = !disable_fused_runner_ &&
bool use_fused_runner = !disable_fused_self_attention_ &&
(nullptr == mask_index || is_mask_1d_seq_len) &&
nullptr == past &&
nullptr == present &&

View file

@ -21,8 +21,9 @@ class Attention final : public CudaKernel, public AttentionBase {
Status ComputeInternal(OpKernelContext* context) const override;
protected:
bool disable_fused_runner_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
};

View file

@ -620,12 +620,14 @@ Status QkvToContext(
if (use_fused_kernel || use_fused_causal) {
int* sequence_offset = reinterpret_cast<int*>(scratch1);
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
DUMP_TENSOR_D("mask", reinterpret_cast<const int*>(data.mask_index), batch_size, sequence_length);
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
} else {
sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
data.mask_index, batch_size, sequence_length, stream,
sequence_offset);
}
DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);

View file

@ -42,8 +42,8 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
@ -124,7 +124,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}
bool use_fused_runner = !disable_fused_runner_ &&
bool use_fused_runner = !disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == relative_position_bias &&
(value != nullptr || key == nullptr) &&

View file

@ -24,7 +24,7 @@ class MultiHeadAttention final : public CudaKernel {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
bool disable_fused_runner_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
bool disable_memory_efficient_attention_;

View file

@ -543,6 +543,7 @@ def get_ort_environment_variables():
# Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
env_names = [
"ORT_DISABLE_FUSED_ATTENTION",
"ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
"ORT_DISABLE_TRT_FLASH_ATTENTION",
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",

View file

@ -24,7 +24,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from benchmark_helper import Precision
from float16 import float_to_float16_max_diff
from fusion_options import AttentionMaskFormat
from io_binding_helper import IOBindingHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
@ -188,6 +187,7 @@ class Gpt2Helper:
input_ids_dtype: torch.dtype = torch.int32,
position_ids_dtype: torch.dtype = torch.int32,
attention_mask_dtype: torch.dtype = torch.int32,
left_side_padding: bool = True,
) -> Gpt2Inputs:
"""Create random inputs for GPT2 model.
Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
@ -218,9 +218,14 @@ class Gpt2Helper:
dtype=attention_mask_dtype,
device=device,
)
if total_sequence_length >= 2:
padding_position = random.randint(0, total_sequence_length - 1) # test input with padding.
attention_mask[:, padding_position] = 0
for i in range(batch_size):
padding_length = random.randint(0, total_sequence_length - 1)
if left_side_padding:
attention_mask[i, :padding_length] = 0
else: # right side padding
attention_mask[i, total_sequence_length - padding_length :] = 0
# Deduce position_ids from attention mask
position_ids = None
@ -517,11 +522,6 @@ class Gpt2Helper:
optimization_options = FusionOptions("gpt2")
if is_float16 and stage == 1:
# For init_decoder, enable mask index to use fused causal cuda kernel.
# Potentially, we can add other optimization like unpad for effective transformer
optimization_options.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
# TODO(hasesh): Investigate parity issue for GPT-2 fp16 when SkipLayerNormalization
# is enabled
if is_float16:
@ -841,6 +841,7 @@ class Gpt2Helper:
input_ids_dtype=input_ids_dtype,
position_ids_dtype=position_ids_dtype,
attention_mask_dtype=attention_mask_dtype,
left_side_padding=True,
)
outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs)
if use_io_binding:
@ -868,6 +869,7 @@ class Gpt2Helper:
max_abs_diff_list.append(max_abs_diff)
if is_all_close:
passed_test_cases += 1
if is_top1_matched:
top1_matched_cases += 1
top1_matched_cases_per_run[run_id] += 1

View file

@ -930,7 +930,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);
@ -941,7 +942,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);
@ -952,7 +954,8 @@ TEST(AttentionTest, Causal_EmptyPastState) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
{onnxruntime::contrib::attention::kEnableFusedCausalAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
use_past_state, past_sequence_length, &past_data, &present_data);

View file

@ -181,7 +181,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
@ -195,7 +195,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
@ -209,7 +209,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
@ -224,7 +224,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
@ -239,7 +239,7 @@ static void RunMultiHeadAttentionKernel(
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(