From 54bbbb78ae3eefe69bd27f71dcc5affa977842c2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Sun, 12 Apr 2020 22:55:37 -0700 Subject: [PATCH] Change mask_index input of Attention op to be optional (#3459) Change Mask Index to optional --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 69 ++++++++++++------- .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 42 ++++++++++- .../contrib_ops/cuda/bert/attention_impl.h | 26 +++---- .../core/graph/contrib_ops/contrib_defs.cc | 3 +- .../test/contrib_ops/attention_op_test.cc | 35 +++++++++- 6 files changed, 133 insertions(+), 46 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index bf5182a2e7..1732a26c77 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -35,7 +35,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { // Input 0 - input : (batch_size, sequence_length, hidden_size) // Input 1 - weights : (hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) - // Input 3 - mask_index : (batch_size) + // Input 3 - mask_index : (batch_size) if presented // Output : (batch_size, sequence_length, hidden_size) const Tensor* input = context->Input(0); @@ -77,13 +77,15 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { } const Tensor* mask_index = context->Input(3); - const auto mask_dims = mask_index->Shape().GetDims(); - if (mask_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ", - mask_dims.size()); - } - if (static_cast(mask_dims[0]) != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0"); + if (mask_index != nullptr) { // mask_index is optional + const auto mask_dims = mask_index->Shape().GetDims(); + if (mask_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ", + mask_dims.size()); + } + if (static_cast(mask_dims[0]) != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0"); + } } return Status::OK(); @@ -179,22 +181,36 @@ Status Attention::Compute(OpKernelContext* context) const { // STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1, // 1, 1) - auto scratch_data = - allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * sequence_length * element_size); + size_t scratch_data_bytes = SafeInt(batch_size) * num_heads_ * sequence_length * sequence_length * element_size; + auto scratch_data = allocator->Alloc(scratch_data_bytes); BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator)); { - auto scratch_broadcast_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * element_size); - BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator)); - memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size); - T* p_scratch_broadcast_current_data = reinterpret_cast(scratch_broadcast_data); - for (int b_i = 0; b_i < batch_size; b_i++) { - // TODO: mask_index can be used in softmax to save some calculation. - int mask = mask_index->template Data()[b_i]; - for (int m_i = mask; m_i < sequence_length; m_i++) { - p_scratch_broadcast_current_data[m_i] = static_cast(-10000.0); + size_t mask_data_bytes = 0; + if (mask_index != nullptr) { + mask_data_bytes = SafeInt(batch_size) * sequence_length * element_size; + } + + void* mask_data = nullptr; + if (mask_data_bytes > 0) { + mask_data = allocator->Alloc(mask_data_bytes); + memset(mask_data, 0, mask_data_bytes); + } + BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); + + if (mask_index != nullptr) { + T* p_mask = reinterpret_cast(mask_data); + for (int b_i = 0; b_i < batch_size; b_i++) { + // TODO: mask_index can be used in softmax to save some calculation. + // Convert mask_index to mask (-10000 means out of range, which will be 0 after softmax): B => BxS + int valid_length = mask_index->template Data()[b_i]; + for (int m_i = valid_length; m_i < sequence_length; m_i++) { + p_mask[m_i] = static_cast(-10000.0); + } + p_mask += sequence_length; } - p_scratch_broadcast_current_data += sequence_length; + } else { + memset(scratch_data, 0, scratch_data_bytes); } const int loop_len = batch_size * num_heads_; @@ -206,12 +222,15 @@ Status Attention::Compute(OpKernelContext* context) const { ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { const std::ptrdiff_t batch_index = i / num_heads_; + // broadcast masks (B) -> (B.N.)S.S - const T* broadcast_data_src = reinterpret_cast(scratch_broadcast_data) + batch_index * sequence_length; - T* broadcast_data_dest = reinterpret_cast(scratch_data) + sequence_length * sequence_length * i; - for (int seq_index = 0; seq_index < sequence_length; seq_index++) { - memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T)); - broadcast_data_dest += sequence_length; + if (mask_index != nullptr) { + const T* broadcast_data_src = reinterpret_cast(mask_data) + batch_index * sequence_length; + T* broadcast_data_dest = reinterpret_cast(scratch_data) + sequence_length * sequence_length * i; + for (int seq_index = 0; seq_index < sequence_length; seq_index++) { + memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T)); + broadcast_data_dest += sequence_length; + } } // gemm diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 0b70556ffa..4959b07b6a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -40,7 +40,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { // Input 0 - input : (batch_size, sequence_length, hidden_size) // Input 1 - weights : (hidden_size, 3 * hidden_size) // Input 2 - bias : (3 * hidden_size) - // Input 3 - mask_index : (batch_size) + // Input 3 - mask_index : (batch_size) if presented // Output : (batch_size, sequence_length, hidden_size) const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); @@ -88,7 +88,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto temp_buffer = GetScratchBuffer(workSpaceSize); if (!LaunchAttentionKernel( reinterpret_cast(gemm_buffer.get()), - mask_index->template Data(), + nullptr == mask_index ? nullptr : mask_index->template Data(), output->template MutableData(), batch_size, sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 26aee9affb..adb73cbf64 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -152,6 +152,38 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T* } } +template +__global__ void SoftmaxKernelSmall(const int sequence_length, const T* input, T* output) { + SoftmaxSmall(sequence_length, sequence_length, input, output); +} + +template +__global__ void SoftmaxKernel(const int sequence_length, const T* input, T* output) { + Softmax(sequence_length, sequence_length, input, output); +} + +template +bool ComputeSoftmax( + cudaStream_t stream, const int sequence_length, const int batch_size, const int num_heads, + const T* input, T* output) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else if (sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else if (sequence_length == 384) { + const int blockSize = 384; + SoftmaxKernelSmall<<>>(sequence_length, input, output); + } else { + const int blockSize = 256; + SoftmaxKernel<<>>(sequence_length, input, output); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + template __global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) { __shared__ int num_valid; @@ -390,8 +422,14 @@ bool QkvToContext( } // apply softmax and store result P to scratch2: BxNxSxS - if (!ComputeMaskedSoftmax(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) { - return false; + if (nullptr != mask_index) { + if (!ComputeMaskedSoftmax(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) { + return false; + } + } else { + if (!ComputeSoftmax(stream, sequence_length, batch_size, num_heads, scratch1, scratch2)) { + return false; + } } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 5e8a53c539..d3a1857fce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -7,20 +7,20 @@ namespace onnxruntime { namespace contrib { namespace cuda { - size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length); +size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length); - bool LaunchAttentionKernel( - const void* input, // Input tensor - const int* mask_index, // Nask index where each element is length of a sequence - void* output, // Output tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - void* workspace, // Temporary buffer - cublasHandle_t& cublas, // Cublas handle - const size_t element_size // Element size of input tensor - ); +bool LaunchAttentionKernel( + const void* input, // Input tensor + const int* mask_index, // Mask index (length of each sequence). NULL means no mask. + void* output, // Output tensor + int batch_size, // Batch size (B) + int sequence_length, // Sequence length (S) + int num_heads, // Number of attention heads (N) + int head_size, // Hidden layer size per head (H) + void* workspace, // Temporary buffer + cublasHandle_t& cublas, // Cublas handle + const size_t element_size // Element size of input tensor +); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5472912013..25cfd774e1 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -299,7 +299,7 @@ void RegisterBertSchemas() { .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") - .Input(3, "mask_index", "Attention mask index with shape (batch_size)", "M") + .Input(3, "mask_index", "Attention mask index with shape (batch_size).", "M", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types") @@ -2325,7 +2325,6 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); RegisterBertSchemas(); - } } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 464524b016..571d4e99b1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -39,16 +39,18 @@ static void RunAttentionTest( tester.AddInput("input", input_dims, ToFloat16(input_data)); tester.AddInput("weight", weights_dims, ToFloat16(weights_data)); tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); - tester.AddInput("mask_index", mask_index_dims, mask_index_data); tester.AddOutput("output", output_dims, ToFloat16(output_data)); } else { tester.AddInput("input", input_dims, input_data); tester.AddInput("weight", weights_dims, weights_data); tester.AddInput("bias", bias_dims, bias_data); - tester.AddInput("mask_index", mask_index_dims, mask_index_data); tester.AddOutput("output", output_dims, output_data); } + if (mask_index_data.size() > 0) { // mask index is optional. + tester.AddInput("mask_index", mask_index_dims, mask_index_data); + } + tester.Run(); } } @@ -204,5 +206,34 @@ TEST(AttentionTest, AttentionMaskExceedSequence) { batch_size, sequence_length, hidden_size, number_of_heads); } +TEST(AttentionTest, AttentionNoMaskIndex) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // No mask_index + std::vector mask_index_data = {}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} } // namespace test } // namespace onnxruntime