diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 46f2b292d5..b42ad93f1d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -112,11 +112,12 @@ class AttentionCPUBase : public AttentionBase { { if (mask_data != nullptr) { - PrepareMask(mask_index, mask_index_dims, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length); - } else { // no any mask - memset(attention_probs, 0, static_cast(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T)); + // Convert attention mask data from int to float (0 to -10000.0f). The mask_data shape is BxSxS*. + PrepareMask(mask_index, mask_index_dims, mask_data, batch_size, sequence_length, past_sequence_length); } + memset(attention_probs, 0, static_cast(batch_size) * num_heads_ * sequence_length * all_sequence_length * sizeof(T)); + const int loop_len = batch_size * num_heads_; const float alpha = 1.0f / sqrt(static_cast(head_size)); @@ -127,32 +128,45 @@ class AttentionCPUBase : public AttentionBase { for (std::ptrdiff_t i = begin; i != end; ++i) { const std::ptrdiff_t batch_index = i / num_heads_; - // broadcast mask data: (Bx)SxS* -> (BxNx)SxS* - if (mask_data != nullptr) { - const T* broadcast_data_src = reinterpret_cast(mask_data) + batch_index * sequence_length * all_sequence_length; - T* broadcast_data_dest = reinterpret_cast(attention_probs) + sequence_length * all_sequence_length * i; - memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * all_sequence_length * sizeof(T)); - } - const T* k = K + input_chunk_length * i; if (nullptr != present) { - // concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH + // Concatenate past_K and K : (BxNx)S'xH, (BxNx)SxH -> (BxNx)S*xH k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i); } - // gemm + int offset = sequence_length * all_sequence_length * static_cast(i); + T* output = reinterpret_cast(attention_probs) + offset; + + // Compute Q*K' // original transposed each iteration // A: Q (B x N x) S x H (B x N x) S x H S x H // B: K' (B x N x) S* x H (B x N x) H x S* H x S* // C: attention_probs (B x N x) S x S* (B x N x) S x S* S x S* math::Gemm(CblasNoTrans, CblasTrans, sequence_length, all_sequence_length, head_size, alpha, Q + input_chunk_length * i, k, 1.0, - reinterpret_cast(attention_probs) + sequence_length * all_sequence_length * i, nullptr); + output, nullptr); + + // Apply unidirectional mask and set future words to -10000.0f. + if (is_unidirectional_) { + for (int s = 0; s < sequence_length - 1; s++) { + for (int t = past_sequence_length + s + 1; t < all_sequence_length; t++) { + output[s * all_sequence_length + t] = static_cast(-10000.0f); + } + } + } + + // Apply attention mask + if (mask_data != nullptr) { + const T* attention_mask = reinterpret_cast(mask_data) + batch_index * sequence_length * all_sequence_length; + for (int j = 0; j < sequence_length * all_sequence_length ; j++) { + output[j] += attention_mask[j]; + } + } if (extra_add_qk_data != nullptr) { - int extra_add_qk_offset = static_cast(i) * sequence_length * all_sequence_length; + const T* extra = extra_add_qk_data + offset; for (int j = 0; j < sequence_length * all_sequence_length ; j++) { - attention_probs[extra_add_qk_offset+j] += extra_add_qk_data[extra_add_qk_offset + j]; + output[j] += extra[j]; } } } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index cf9408a2db..c274dee7ea 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -64,7 +64,6 @@ template void PrepareMask(const int32_t* mask_index, const std::vector* mask_index_dims, T* mask_data, - bool is_unidirectional, int batch_size, int sequence_length, int past_sequence_length) { @@ -84,18 +83,6 @@ void PrepareMask(const int32_t* mask_index, for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { p_mask[i] = (mask_index[i] > 0) ? static_cast(0.0f) : static_cast(-10000.0f); } - - if (is_unidirectional) { - for (int b_i = 0; b_i < batch_size; b_i++) { - for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - p_mask[s_i * all_sequence_length + m_i] += static_cast(-10000.0f); - } - } - p_mask += sequence_length * all_sequence_length; - } - } - return; } @@ -135,15 +122,6 @@ void PrepareMask(const int32_t* mask_index, memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); } - // Apply unidirectional mask. - if (is_unidirectional) { - for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - p_mask[s_i * all_sequence_length + m_i] += static_cast(-10000.0f); - } - } - } - p_mask += sequence_length * all_sequence_length; } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 106bedcbb8..e9c41f7754 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -27,6 +27,7 @@ limitations under the License. #include "core/providers/cuda/shared_inc/fpgeneric.h" #include "attention_impl.h" #include "attention_softmax.h" +#include "transformer_common.h" using namespace onnxruntime::cuda; using namespace cub; @@ -65,7 +66,7 @@ bool QkvToContext( const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size, const T* input, T* output, T* workspace, const int* mask_index, const std::vector* mask_index_dims, - bool is_unidirectional, int past_sequence_length, const T* past, T* present) { + bool is_unidirectional, int past_sequence_length, const T* past, T* present, bool use_persistent_softmax) { const int all_sequence_length = past_sequence_length + sequence_length; const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); T* scratch1 = workspace; @@ -127,8 +128,11 @@ bool QkvToContext( if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask const int mask_dimension = static_cast(mask_index_dims->size()); const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims->at(3) : 0; + + T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected. if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, - rsqrt_head_size, mask_dimension, static_cast(max_sequence_length))) { + rsqrt_head_size, mask_dimension, static_cast(max_sequence_length), + use_persistent_softmax, persistent_softmax_workspace)) { return false; } } else if (nullptr != mask_index) { // 1d mask index @@ -173,18 +177,27 @@ bool LaunchAttentionKernel( int past_sequence_length, const void* past, void* present) { + + // GPT-2 model is more sensitive on parity since error will accumulate in text generation. + // So use persistent softmax for GPT-2 model by default. + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 or 2 could enable or disable it explicitly. + const TransformerOptions* options = TransformerOptions::GetInstance(); + bool use_persistent_softmax = (is_unidirectional || options->IsPrecisionMode()) && !options->DisablePersistentSoftmax(); + if (element_size == 2) { return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), reinterpret_cast(present)); + past_sequence_length, reinterpret_cast(past), reinterpret_cast(present), + use_persistent_softmax); } else { return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), mask_index, mask_index_dims, is_unidirectional, - past_sequence_length, reinterpret_cast(past), reinterpret_cast(present)); + past_sequence_length, reinterpret_cast(past), reinterpret_cast(present), + use_persistent_softmax); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h index a3cc1a9a7d..c77af81d42 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -19,11 +19,13 @@ limitations under the License. #pragma once +#include #include #include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/math/softmax.h" using namespace onnxruntime::cuda; using namespace cub; @@ -165,9 +167,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, const T* input, T* output, const bool is_unidirectional, - const float scalar, + const float rsqrt_head_size, const int mask_dimension, - const int max_sequence_length) { + const int max_sequence_length, + const bool skip_softmax) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -179,30 +182,36 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, float thread_data = -CUDART_INF_F; if (threadIdx.x < all_sequence_length) { - const int batch_index = blockIdx.y; + thread_data = float(input[index]) * rsqrt_head_size; + const int sequence_index = blockIdx.x % sequence_length; + if (is_unidirectional) { + int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. + if (threadIdx.x > from_index) { + thread_data = -10000.0f; + } + } + int mask_offset = 0; + const int batch_index = blockIdx.y; if (mask_dimension == 2) { mask_offset = batch_index * all_sequence_length + threadIdx.x; } else if (mask_dimension == 3) { mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; - } else if (mask_dimension == 4){ - // Megatron code: - // ltor_mask = ltor_mask[..., (attention_scores.size(3)-hidden_states.size(1)):attention_scores.size(3), :attention_scores.size(3)] + } else if (mask_dimension == 4) { mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x; } const int& mask = attention_mask[mask_offset]; - float mask_value = mask > 0 ? 0.0f : -10000.0f; + if (mask == 0) + thread_data += -10000.0f; + } - if (is_unidirectional) { - int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. - if (threadIdx.x > from_index) { - mask_value += -10000.0f; - } + if (skip_softmax) { + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data); } - - thread_data = float(input[index]) * scalar + mask_value; + return; } const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); @@ -264,7 +273,7 @@ bool ComputeSoftmax( const int blockSize = 1024; SoftmaxKernel<<>>(all_sequence_length, sequence_length, input, output); } else { - ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); + ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); } return CUDA_CALL(cudaPeekAtLastError()); @@ -314,8 +323,9 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq template __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, - const bool is_unidirectional, const float scalar, const int mask_dimension, const int max_sequence_length) { - SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, + const bool skip_softmax) { + SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax); } template @@ -352,7 +362,7 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length MaskedSoftmaxKernel <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output); } else { - ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); + ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); } return CUDA_CALL(cudaPeekAtLastError()); @@ -360,36 +370,42 @@ bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length template bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, - const int mask_dimension, const int max_sequence_length) { + const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float rsqrt_head_size, + const int mask_dimension, const int max_sequence_length, + const bool use_persistent_softmax, T* persistent_softmax_workspace) { const dim3 grid(sequence_length * num_heads, batch_size, 1); + T* out = use_persistent_softmax ? persistent_softmax_workspace : output; if (all_sequence_length <= 32) { const int blockSize = 32; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else if (all_sequence_length <= 64) { const int blockSize = 64; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else if (all_sequence_length <= 128) { const int blockSize = 128; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else if (all_sequence_length <= 256) { const int blockSize = 256; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else if (all_sequence_length <= 512) { const int blockSize = 512; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else if (all_sequence_length <= 1024) { const int blockSize = 1024; SoftmaxWithRawMaskSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension, max_sequence_length); + <<>>(all_sequence_length, sequence_length, attention_mask, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); } else { - ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024."); + ORT_THROW("Attention CUDA operator does not support total sequence length > 1024."); + } + + if (use_persistent_softmax) { + dispatch_softmax_forward(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length); } return CUDA_CALL(cudaPeekAtLastError()); diff --git a/onnxruntime/contrib_ops/cuda/bert/transformer_common.cc b/onnxruntime/contrib_ops/cuda/bert/transformer_common.cc new file mode 100644 index 0000000000..58eed22996 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/transformer_common.cc @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include "core/providers/shared_library/provider_api.h" // Include this otherwise Windows build complains Env::Default() missing +#include "core/platform/env_var_utils.h" +#include "transformer_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// The environment variable is for testing purpose only, and it might be removed in the future. +// If you need some option in production, please file a feature request. +constexpr const char* kTransformerOptions = "ORT_TRANSFORMER_OPTIONS"; + +// Initialize the singleton instance +TransformerOptions TransformerOptions::instance; + +const TransformerOptions* TransformerOptions::GetInstance() { + if (!instance.initialized_) { + // We do not use critical section here since it is fine to initialize multiple times by different threads. + int value = ParseEnvironmentVariableWithDefault(kTransformerOptions, 0); + instance.Initialize(value); + + if (value > 0) + std::cout << "ORT_TRANSFORMER_OPTIONS: IsPrecisionMode=" << instance.IsPrecisionMode() + << ",DisablePersistentSoftmax=" << instance.DisablePersistentSoftmax() << std::endl; + } + + return &instance; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/transformer_common.h b/onnxruntime/contrib_ops/cuda/bert/transformer_common.h new file mode 100644 index 0000000000..fad851a937 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/transformer_common.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +class TransformerOptions { + public: + static const TransformerOptions* GetInstance(); + + bool IsPrecisionMode() const { return is_precision_mode_; } + + bool DisablePersistentSoftmax() const { return disable_persistent_softmax_; } + + void Initialize(int value) { + is_precision_mode_ = (value & 0x01) > 0; + disable_persistent_softmax_ = (value & 0x02) > 0; + initialized_ = true; + } + + private: + // Default is false. If the mode is on, prefer precision than speed. + bool is_precision_mode_{false}; + + // Disable persistent softmax. + bool disable_persistent_softmax_{false}; + + bool initialized_{false}; + + static TransformerOptions instance; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 70a94da700..ef858237ce 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -39,6 +39,7 @@ class FusionSkipLayerNormalization(Fusion): if self.shape_infer_helper is not None: if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): + logger.debug(f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same") return else: # shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index d0292ba7db..7bf8ced4e6 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -799,6 +799,7 @@ class OnnxModel: else: deps_to_nodes[input_name].append(node_idx) + # Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph initializer_names = [init.name for init in graph.initializer] graph_input_names = [input.name for input in graph.input] input_names = initializer_names + graph_input_names diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 20e33687b0..e04c63db49 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -185,9 +185,6 @@ class BertOnnxModel(OnnxModel): return def adjust_reshape_and_expand(self): - # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. - FusionUtils.remove_useless_reshape_nodes(self) - nodes_to_remove = [] for node in self.nodes(): if node.op_type == 'Reshape': @@ -289,7 +286,9 @@ class BertOnnxModel(OnnxModel): if (options is None) or options.enable_embed_layer_norm: self.fuse_embed_layer() - # Post-processing like removing extra reshape nodes. + # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. + FusionUtils.remove_useless_reshape_nodes(self) + self.postprocess() # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 9cd2863de4..03c8cd32fd 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -1291,7 +1291,7 @@ TEST(AttentionTest, AttentionUnidirectional3DMask) { 1, 1}; std::vector output_data = { - 3.967245340f, 0.07324841f, 4.25f, 5.65f, + 3.0146f, 0.1142f, 3.9834f, 5.3394f, 8.69f, -0.13f, 4.25f, 5.65f, 8.69f, -0.13f, 4.25f, 5.65f, 3.96967912f, 0.07314367f, 4.25f, 5.65f}; @@ -1332,7 +1332,7 @@ TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { std::vector mask_index_data = {0, 1, 1, 1}; std::vector output_data = { - 3.967245340f, 0.07324841f, 4.25f, 5.65f, + 3.0146f, 0.1142f, 3.9834f, 5.3394f, 8.69f, -0.13f, 4.25f, 5.65f, 8.69f, -0.13f, 4.25f, 5.65f, 3.96967912f, 0.07314367f, 4.25f, 5.65f}; diff --git a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py new file mode 100644 index 0000000000..023fff85dd --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py @@ -0,0 +1,511 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import unittest +import pytest +import torch +from torch import nn +import random +from onnx import helper +import onnx +import numpy +import os +from transformers.modeling_utils import Conv1D + +DEBUG_OUTPUTS = ["qk", "norm_qk", "softmax", "attn_weights"] + + +class MyGPT2Attention(nn.Module): + """ + This module is modifed from Gpt2Attention of huggingface transformers v4.9.1. + Code related to crosss attention, c_proj, attn_dropout and head_mask etc are removed. + """ + def __init__(self, + max_position_embeddings=1024, + hidden_size=768, + num_attention_heads=12, + use_cache=True, + debug=False, + fix_onnx_export=True): + super().__init__() + max_positions = max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), + dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + assert self.head_dim * self.num_heads == self.embed_dim + + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + # Use random bias instead of zeros for parity test. + self.c_attn.bias = nn.Parameter(torch.normal(0.0, 0.1, (3 * self.embed_dim, ))) + + self.use_cache = use_cache + self.debug = debug + self.fix_onnx_export = fix_onnx_export + + def _attn(self, query, key, value, attention_mask=None): + qk = torch.matmul(query, key.transpose(-1, -2)) + + # Torch has special handling for Div and Mul by a scalar: + # https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L52-L60 + # https://github.com/pytorch/pytorch/blob/5536cda19a5def9e0553b318f04d297d602ac956/aten/src/ATen/native/cuda/BinaryMulDivKernel.cu#L185-L194 + # Modify the code to use same processing in onnx export so as to get parity result without attention fusion. + # This walkaround is not needed when attention fusion will be applied later since the subgraph will be replaced by an Attention node. + if self.fix_onnx_export and torch.onnx.is_in_onnx_export(): + if qk.dtype == torch.float16: + norm_qk = qk.to(torch.float32) * (1.0 / (float(value.size(-1))**0.5)) + norm_qk = norm_qk.to(torch.float16) + else: + norm_qk = qk * (1.0 / (float(value.size(-1))**0.5)) + else: + norm_qk = qk / (float(value.size(-1))**0.5) + + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + attn_weights = torch.where(causal_mask, norm_qk, self.masked_bias.to(norm_qk.dtype)) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + softmax = nn.Softmax(dim=-1)(attn_weights) + + attn_output = torch.matmul(softmax, value) + + if self.debug: + return attn_output, qk, norm_qk, softmax, attn_weights + else: + return attn_output + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(*new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, ) + return tensor.view(new_shape) + + @staticmethod + def concat_key_value(key, value): + return torch.cat((key.unsqueeze(0), value.unsqueeze(0)), dim=0) + + @staticmethod + def process_mask(attention_mask, dtype): + # Create a 4D attention mask with shape [batch_size, 1, 1, to_seq_length] from a 2D tensor mask. + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + return attention_mask + + def forward(self, hidden_states, attention_mask=None, layer_past=None): + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if self.use_cache is True: + # Instead of present = (key, value), here we merge them into one tensor to be compatible with Attention operator. + present = MyGPT2Attention.concat_key_value(key, value) + else: + present = None + + mask = MyGPT2Attention.process_mask(attention_mask, dtype=query.dtype) # mask processing is moved to here. + + if self.debug: + attn_output, qk, norm_qk, softmax, attn_weights = self._attn(query, key, value, mask) + else: + attn_output = self._attn(query, key, value, mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + outputs = (attn_output, present) + if self.debug: + if "qk" in DEBUG_OUTPUTS: + outputs += (qk, ) + if "norm_qk" in DEBUG_OUTPUTS: + outputs += (norm_qk, ) + if "softmax" in DEBUG_OUTPUTS: + outputs += (softmax, ) + if "attn_weights" in DEBUG_OUTPUTS: + outputs += (attn_weights, ) + + return outputs + + +def create_inputs(batch_size=1, + hidden_size=768, + num_attention_heads=12, + sequence_length=1, + past_sequence_length=5, + float16=False, + device=torch.device('cuda'), + padding_length=0): + float_type = torch.float16 if float16 else torch.float32 + + past_shape = [batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)] + past_key = torch.rand(past_shape, dtype=float_type, device=device) + past_value = torch.rand(past_shape, dtype=float_type, device=device) + layer_past = MyGPT2Attention.concat_key_value(past_key, past_value) + + total_sequence_length = past_sequence_length + sequence_length + + attention_mask = torch.ones([batch_size, total_sequence_length], dtype=torch.int32, device=device) + if padding_length > 0: + padding_position = total_sequence_length - padding_length + attention_mask[:, padding_position:] = 0 + elif padding_length < 0: # mask a random position + for i in range(batch_size): + padding_position = random.randint(0, total_sequence_length - 1) + attention_mask[i, padding_position] = 0 + + input_hidden_states = torch.normal(mean=0.0, std=0.1, + size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device) + return input_hidden_states, attention_mask, layer_past + + +def get_output_names(debug=False): + outputs = ["attn_output", "present"] + if debug: + outputs += DEBUG_OUTPUTS + return outputs + + +def export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device): + from pathlib import Path + Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) + + input_hidden_states, attention_mask, layer_past = create_inputs(float16=float16, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + device=device) + + with torch.no_grad(): + outputs = model(input_hidden_states, attention_mask=attention_mask, layer_past=layer_past) + + dynamic_axes = { + 'input_hidden_states': { + 0: 'batch_size', + 1: 'seq_len' + }, + "attn_output": { + 0: 'batch_size', + 1: 'seq_len' + }, + "past": { + 1: 'batch_size', + 3: 'past_seq_len' + }, + "present": { + 1: 'batch_size', + 3: 'total_seq_len' + }, + "attention_mask": { + 0: 'batch_size', + 1: 'total_seq_len' + } + } + if debug: + debug_dynamic_axes = { + "qk": { + 0: 'batch_size', + 1: 'seq_len' + }, + "norm_qk": { + 0: 'batch_size', + 1: 'seq_len' + }, + "softmax": { + 0: 'batch_size', + 1: 'seq_len' + }, + "attn_weights": { + 0: 'batch_size', + 1: 'seq_len' + } + } + for name in DEBUG_OUTPUTS: + dynamic_axes[name] = debug_dynamic_axes[name] + + torch.onnx.export(model, + args=(input_hidden_states, { + 'attention_mask': attention_mask, + 'layer_past': layer_past + }), + f=onnx_model_path, + input_names=['input_hidden_states', 'attention_mask', 'past'], + output_names=get_output_names(debug), + dynamic_axes=dynamic_axes, + example_outputs=outputs, + opset_version=11, + do_constant_folding=True) + print("exported:", onnx_model_path) + + +def optimize_onnx(input_onnx_path, optimized_onnx_path, num_heads, debug): + from onnxruntime.transformers.onnx_model import OnnxModel + m = onnx.load(input_onnx_path) + onnx_model = OnnxModel(m) + + nodes_to_remove = onnx_model.nodes() + output_names = ["attn_output", "present"] + DEBUG_OUTPUTS if debug else ["attn_output", "present"] + node_to_add = helper.make_node("Attention", + ["input_hidden_states", "c_attn.weight", "c_attn.bias", "attention_mask", "past"], + output_names, + "gpt2_attention", + num_heads=num_heads, + unidirectional=1, + domain="com.microsoft") + + onnx_model.remove_nodes(nodes_to_remove) + onnx_model.add_node(node_to_add) + onnx_model.prune_graph() + onnx_model.save_model_to_file(optimized_onnx_path) + + +def diff_outputs(torch_outputs, ort_outputs, index, relative=False): + """ Returns the maximum difference between PyTorch and OnnxRuntime outputs. + """ + expected_outputs = torch_outputs[index].cpu().numpy() + diff = numpy.abs(expected_outputs - ort_outputs[index]) + if relative: + return numpy.amax(diff / (numpy.abs(expected_outputs) + 1e-6)) + else: + return numpy.amax(diff) + + +def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03, verbose=False): + """ Returns True if torch and ORT outputs are close for given thresholds, and False otherwise. + """ + is_all_close = True + max_abs_diff = [] + for i in range(len(ort_outputs)): + is_close = numpy.allclose(ort_outputs[i], torch_outputs[i].cpu().numpy(), rtol=rtol, atol=atol) + if not is_close: + is_all_close = False + if verbose: + print(f'output {i} ({ort_outputs[i].name}) are close: {is_close}') + max_abs_diff.append(diff_outputs(torch_outputs, ort_outputs, i)) + + if (not is_all_close) or (verbose and max(max_abs_diff) > 0): + messages = ["max_abs_diff per output:"] + output_names = ["attn_output", "present"] + DEBUG_OUTPUTS + for i, diff in enumerate(max_abs_diff): + messages.append(f"{output_names[i]}={diff},") + print(" ".join(messages)) + + return is_all_close, max(max_abs_diff) + + +def create_ort_session(onnx_model_path, use_gpu=True): + from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version + sess_options = SessionOptions() + sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.intra_op_num_threads = 2 + sess_options.log_severity_level = 2 + execution_providers = ['CPUExecutionProvider' + ] if not use_gpu else ['CUDAExecutionProvider', 'CPUExecutionProvider'] + return InferenceSession(onnx_model_path, sess_options, providers=execution_providers) + + +def onnxruntime_inference(ort_session, input_hidden_states, attention_mask, past): + ort_inputs = { + 'past': numpy.ascontiguousarray(past.cpu().numpy()), + 'attention_mask': numpy.ascontiguousarray(attention_mask.cpu().numpy()), + 'input_hidden_states': numpy.ascontiguousarray(input_hidden_states.cpu().numpy()), + } + + ort_outputs = ort_session.run(None, ort_inputs) + return ort_outputs + + +def verify_attention(model, + onnx_model_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases=100): + print( + f"optimized={optimized}, batch_size={batch_size}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}, sequence_length={sequence_length}, past_sequence_length={past_sequence_length}, float16={float16}, padding_length={padding_length}, device={device}" + ) + passed_cases = 0 + max_diffs = [] + + ort_session = create_ort_session(onnx_model_path, device.type == 'cuda') + for i in range(test_cases): + input_hidden_states, attention_mask, layer_past = create_inputs(batch_size, hidden_size, num_attention_heads, + sequence_length, past_sequence_length, float16, + device, padding_length) + + with torch.no_grad(): + torch_outputs = model(input_hidden_states, layer_past=layer_past, attention_mask=attention_mask) + + ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past) + + tolerance = 1e-03 if float16 else 1e-05 + is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, rtol=tolerance, atol=tolerance) + max_diffs.append(max_diff) + if is_all_close: + passed_cases += 1 + + max_diff = max(max_diffs) + diff_count = len([i for i in max_diffs if i > 0]) + success_flag = "[FAILED]" if passed_cases < test_cases else "[OK]" + print(f"{success_flag} Passed_cases={passed_cases}/{test_cases}; Max_diff={max_diff}; Diff_count={diff_count}") + return test_cases - passed_cases + + +def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, test_cases): + test_name = f"batch_size={batch_size}, float16={float16}, optimized={optimized}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}" + print(f"\nTesting ONNX parity: {test_name}") + + debug = (not optimized + ) # or DEBUG_OUTPUTS==["softmax"] when you add an extra output for softmax result in Attention operator + model = MyGPT2Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, debug=debug) + model.eval() + model.to(device) + if float16: + model.half() + + # Do not re-use onnx file from previous test since weights of model are random. + onnx_model_path = './temp/gpt_attention_{}.onnx'.format("fp16" if float16 else "fp32") + export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device) + + if optimized: + optimized_onnx_path = './temp/gpt_attention_opt_{}.onnx'.format("fp16" if float16 else "fp32") + optimize_onnx(onnx_model_path, optimized_onnx_path, num_attention_heads, debug) + onnx_path = optimized_onnx_path + else: + onnx_path = onnx_model_path + + # Test Case: No past state + sequence_length = 2 + past_sequence_length = 0 + padding_length = 0 + num_failure = 0 + num_failure += verify_attention(model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases) + + # Test Case: with past state and padding last 2 words + sequence_length = 3 + past_sequence_length = 5 + padding_length = 2 + num_failure += verify_attention(model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases) + + # Test Case: random mask one word + sequence_length = 1 + past_sequence_length = 128 + padding_length = -1 + num_failure += verify_attention(model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases) + + # clean up onnx file + os.remove(onnx_model_path) + if optimized: + os.remove(onnx_path) + + return num_failure, test_name + + +class TestGptAttentionHuggingfaceParity(unittest.TestCase): + def setUp(self): + self.optimized = True # Change it to False if you want to test parity of non optimized ONNX + self.test_cases = 10 # Number of test cases per test run + + def run_test(self, batch_size, float16, optimized, hidden_size, num_attention_heads, device): + if float16 and device.type=='cpu': # CPU does not support FP16 + return + num_failure, test_name = run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, self.test_cases) + self.assertTrue(num_failure == 0, test_name) + + def run_small(self, optimized, device): + for batch_size in [64]: + self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device) + self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=768, num_attention_heads=12, device=device) + + def run_large(self, optimized, device): + for batch_size in [2]: + self.run_test(batch_size, float16=False, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device) + self.run_test(batch_size, float16=True, optimized=optimized, hidden_size=4096, num_attention_heads=32, device=device) + + def test_cpu(self): + cpu = torch.device('cpu') + self.run_small(self.optimized, cpu) + + def test_cuda(self): + if not torch.cuda.is_available(): + import pytest + pytest.skip('test requires GPU and torch+cuda') + else: + gpu = torch.device('cuda') + self.run_small(self.optimized, gpu) + + @pytest.mark.slow + def test_large_cuda(self): + if not torch.cuda.is_available(): + import pytest + pytest.skip('test requires GPU and torch+cuda') + else: + gpu = torch.device('cuda') + self.run_large(self.optimized, gpu) + +if __name__ == '__main__': + unittest.main()