From fd16085cea3d466a3e71a7cfb47373b616f796a1 Mon Sep 17 00:00:00 2001 From: zhangyaobit <1034716+zhangyaobit@users.noreply.github.com> Date: Thu, 17 Feb 2022 09:02:45 -0800 Subject: [PATCH] Zhanyao/attention (#10545) * Enable Attention op for ROCM EP. As a note, potential hipify improvements: (1) handle math contants (attention_softmax.h), (2) correctly generate transpose options for the GEMM helpers, consider counterpart/dummy API for CublasMathModeSetter (attention_impl.cu, attention_impl.cu). After these improvements, we don't need to manually keep copies of the above mentioned files any more. * Clean up debugging code. --- .../contrib_ops/rocm/bert/attention.cc | 126 ++++++ .../contrib_ops/rocm/bert/attention_impl.cu | 213 +++++++++ .../contrib_ops/rocm/bert/attention_softmax.h | 423 ++++++++++++++++++ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 4 +- .../test/contrib_ops/attention_op_test.cc | 9 +- tools/ci_build/amd_hipify.py | 5 +- 6 files changed, 773 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention.cc create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_impl.cu create mode 100644 onnxruntime/contrib_ops/rocm/bert/attention_softmax.h diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc new file mode 100644 index 0000000000..1320f4c3c8 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/attention.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +Attention::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info) {} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* weights = context->Input(1); + const Tensor* bias = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* past = context->Input(4); + const Tensor* extra_add_qk = context->Input(5); + + auto& device_prop = GetDeviceProp(); + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock)); + + // input shape (batch_size, sequence_length, input_hidden_size) + const auto& shape = input->Shape(); + int batch_size = static_cast(shape[0]); + int sequence_length = static_cast(shape[1]); + int input_hidden_size = static_cast(shape[2]); + + // bias shape (3 * hidden_size) + const auto& bias_shape = bias->Shape(); + int hidden_size = static_cast(bias_shape[0]) / 3; + + int head_size = hidden_size / num_heads_; + + TensorShapeVector output_shape(3); + output_shape[0] = shape[0]; + output_shape[1] = shape[1]; + output_shape[2] = static_cast(hidden_size); + Tensor* output = context->Output(0, output_shape); + + int past_sequence_length = 0; + Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length); + + rocblas_handle rocblas = RocblasHandle(); + constexpr size_t element_size = sizeof(T); + + // Use GEMM for fully connection. + int m = batch_size * sequence_length; + int n = 3 * hidden_size; + int k = input_hidden_size; + auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); + + typedef typename ToHipType::MappedType HipT; + HipT one = ToHipType::FromFloat(1.0f); + HipT zero = ToHipType::FromFloat(0.0f); + + // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B. + // TODO: use custom kernel of expand to improve the performance. + ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper( + rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one, + reinterpret_cast(bias->template Data()), n, + GetConstOnes(m), 1, + &zero, reinterpret_cast(gemm_buffer.get()), n)); + + // Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B. + ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper( + rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one, + reinterpret_cast(weights->template Data()), n, + reinterpret_cast(input->template Data()), k, + &one, reinterpret_cast(gemm_buffer.get()), n)); + + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length); + auto temp_buffer = GetScratchBuffer(workSpaceSize); + if (!LaunchAttentionKernel( + device_prop, + Stream(), + reinterpret_cast(gemm_buffer.get()), + nullptr == mask_index ? nullptr : mask_index->template Data(), + nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), + output->template MutableData(), + batch_size, + sequence_length, + num_heads_, + head_size, + temp_buffer.get(), + rocblas, + element_size, + is_unidirectional_, + past_sequence_length, + nullptr == past ? nullptr : past->template Data(), + nullptr == extra_add_qk ? nullptr : extra_add_qk->template Data(), + nullptr == present ? nullptr : present->template MutableData())) { + // Get last error to reset it to hipSuccess. + HIP_CALL(hipGetLastError()); + return Status(common::ONNXRUNTIME, common::FAIL); + } + + return Status::OK(); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu new file mode 100644 index 0000000000..83514aff9b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -0,0 +1,213 @@ +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/transformer_common.h" + +using namespace onnxruntime::rocm; +using namespace hipcub; + +#define CHECK_ROCM(expr) \ + if (!HIP_CALL(expr)) { \ + return false; \ + } + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +static size_t AlignTo(size_t a, size_t b) { + return CeilDiv(a, b) * b; +} + +size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) { + const size_t len = batch_size * num_heads * sequence_length * all_sequence_length; + const size_t bytes = len * element_size; + + const size_t alignment = 256; + const size_t bytesAligned = AlignTo(bytes, alignment); + return bytesAligned; +} + +size_t GetAttentionWorkspaceSize( + size_t element_size, + int batch_size, + int num_heads, + int head_size, + int sequence_length, + int past_sequence_length) { + size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size; + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); +} + +template +bool QkvToContext( + const hipDeviceProp_t& prop, rocblas_handle& rocblas, hipStream_t stream, + const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size, + const T* input, T* output, T* workspace, + const int* mask_index, gsl::span mask_index_dims, + bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) { + const int all_sequence_length = past_sequence_length + sequence_length; + const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); + T* scratch1 = workspace; + T* scratch2 = scratch1 + (bytes / element_size); + T* scratch3 = scratch2 + (bytes / element_size); + + const int max_threads_per_block = prop.maxThreadsPerBlock; + + // input should be BxSx3xNxH => scratch3: 3xBxNxSxH + if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) { + return false; + } + + // now scratch3 has Q, K, V: each has size BxNxSxH + const int batches = batch_size * num_heads; + const int size_per_batch = sequence_length * head_size; + const int total_size = batches * size_per_batch; + + const T* q = scratch3; + const T* k = q + total_size; + const T* v = k + total_size; + + rocblas_set_stream(rocblas, stream); + + // Concat past (2xBxNxS'xH) to present (2xBxNxS*xH): + // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH) + // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) + const int present_size_per_batch = all_sequence_length * head_size; + if (nullptr != present) { + if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) { + return false; + } + + // update pointers to present_k and present_v. + k = present; + v = present + batches * present_size_per_batch; + } + + // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length. + bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* + // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * all_sequence_length; + + typedef typename ToHipType::MappedType HipT; + + //float one = 1.0f; + //float zero = 0.f; + const HipT one = ToHipType::FromFloat(1.0f); + const HipT zero = ToHipType::FromFloat(0.f); + + // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation. + //float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size; + const HipT alpha = use_raw_attention_mask ? one : ToHipType::FromFloat(rsqrt_head_size); + + if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( + rocblas, rocblas_operation_transpose, rocblas_operation_none, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch, + q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) { + return false; + } + + // apply softmax and store result P to scratch2: BxNxSxS* + if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask + const int mask_dimension = static_cast(mask_index_dims.size()); + const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0; + + T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected. + if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2, + is_unidirectional, rsqrt_head_size, mask_dimension, static_cast(max_sequence_length), + use_persistent_softmax, persistent_softmax_workspace)) { + return false; + } + } else if (nullptr != mask_index) { // 1d mask index + ORT_ENFORCE(mask_index_dims.size() == 1); + // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. + const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr; + if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + return false; + } + } else { // no mask + if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) { + return false; + } + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper( + rocblas, rocblas_operation_none, rocblas_operation_none, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch, + scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches))) { + return false; + } + + // scratch3 is BxNxSxH, transpose to output BxSxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output); +} + +bool LaunchAttentionKernel( + const hipDeviceProp_t& prop, + hipStream_t stream, + const void* input, + const int* mask_index, + gsl::span mask_index_dims, + void* output, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + void* workspace, + rocblas_handle& rocblas, + const size_t element_size, + bool is_unidirectional, + int past_sequence_length, + const void* past, + const void* extra_add_qk, + void* present) { + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax + const TransformerOptions* options = TransformerOptions::GetInstance(); + bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); + if (element_size == 2) { + return QkvToContext(prop, rocblas, stream, + batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace), + mask_index, mask_index_dims, is_unidirectional, + past_sequence_length, reinterpret_cast(past), reinterpret_cast(extra_add_qk), + reinterpret_cast<__half*>(present), use_persistent_softmax); + } else { + return QkvToContext(prop, rocblas, stream, + batch_size, sequence_length, num_heads, head_size, element_size, + reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), + mask_index, mask_index_dims, is_unidirectional, + past_sequence_length, reinterpret_cast(past), reinterpret_cast(extra_add_qk), + reinterpret_cast(present), use_persistent_softmax); + } +} +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h new file mode 100644 index 0000000000..cbd7472a17 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -0,0 +1,423 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/math/softmax.h" + +#define ROCMRT_INF_F __int_as_float(0x7f800000) + +using namespace onnxruntime::rocm; +using namespace hipcub; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline void Softmax(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* add_before_softmax, + const T* input, + T* output) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + float thread_data_max(-ROCMRT_INF_F); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + if (thread_data_max < input_at_idx) { + thread_data_max = input_at_idx; + } + } + } + + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_sum(0.f); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index]; + thread_data_sum += expf(val - max_block); + } + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum()); + if (threadIdx.x == 0) { + sum_reverse_block = 1.f / sum; + } + __syncthreads(); + + for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + const int index = offset + i; + float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f; + output[index] = T(val); + } +} + +template +__device__ inline void SoftmaxSmall(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* add_before_softmax, + const T* input, + T* output, + bool is_unidirectional) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + const int index = offset + threadIdx.x; + + bool is_valid = false; // whether it has attention mask == 1. + + // Update end position for unidirectional. + int end = valid_end; + if (is_unidirectional) { + int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + if (end_unid <= valid_start) { + // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. + // So [0, end_unid) will also have value after softmax. + is_valid = threadIdx.x < end_unid; + } else { + end = min(valid_end, end_unid); + } + } + + is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + float input_data = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]); + float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F); + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp(0.f); + if (is_valid) { + thread_data_exp = expf(input_data - max_block); + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end); + + // Store value of 1.0/sum. + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, + const int sequence_length, + const int* attention_mask, // 2D, 3D or 4D attention mask + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const int mask_dimension, + const int max_sequence_length, + const bool skip_softmax) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + + float thread_data = -ROCMRT_INF_F; + if (threadIdx.x < all_sequence_length) { + if (add_before_softmax == nullptr) { + thread_data = float(input[index]) * rsqrt_head_size; + } else { + thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size; + } + + const int sequence_index = blockIdx.x % sequence_length; + if (is_unidirectional) { + int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. + if (threadIdx.x > from_index) { + thread_data = -10000.0f; + } + } + + int mask_offset = 0; + const int batch_index = blockIdx.y; + if (mask_dimension == 2) { + mask_offset = batch_index * all_sequence_length + threadIdx.x; + } else if (mask_dimension == 3) { + mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; + } else if (mask_dimension == 4) { + mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x; + } + + if (nullptr == key_padding_mask) { + const int& mask = attention_mask[mask_offset]; + if (mask == 0) + thread_data += -10000.0f; + } else { + const bool mask = key_padding_mask[mask_offset]; + if (mask) { + thread_data = -ROCMRT_INF_F; + } + } + } + + if (skip_softmax) { + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data); + } + return; + } + + const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length); + + // Store value of 1.0/sum + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output, is_unidirectional); +} + +template +__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output) { + Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output); +} + +template +bool ComputeSoftmax( + hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (all_sequence_length <= 32) { + const int blockSize = 32; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output); + } else { + ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +template +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output, is_unidirectional); +} + +template +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + Softmax(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output); +} + +template +__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, + T* output, const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, + const bool skip_softmax) { + SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax); +} + +template +bool ComputeSoftmaxWithMask1D(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + if (all_sequence_length <= 32) { + const int blockSize = 32; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; + hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output); + } else { + ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + } + + return HIP_CALL(hipPeekAtLastError()); +} + +template +bool ComputeSoftmaxWithRawMask(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional, + const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, const bool use_persistent_softmax, T* persistent_softmax_workspace) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + T* out = use_persistent_softmax ? persistent_softmax_workspace : output; + if (all_sequence_length <= 32) { + const int blockSize = 32; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax); + } else { + ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + } + + if (use_persistent_softmax) { + dispatch_warpwise_softmax_forward(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length); + } + + return HIP_CALL(hipPeekAtLastError()); +} + + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index ac2435baa0..841acc144c 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -125,8 +125,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 64e85d88c5..6a5ed608c1 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -46,10 +46,11 @@ static void RunAttentionTest( int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu; + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda; int head_size = hidden_size / number_of_heads; - if (enable_cpu || enable_cuda) { + if (enable_cpu || enable_cuda || enable_rocm) { OpTester tester("Attention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(number_of_heads)); tester.AddAttribute("unidirectional", static_cast(is_unidirectional ? 1 : 0)); @@ -151,6 +152,12 @@ static void RunAttentionTest( tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + if (enable_rocm) { + std::vector> execution_providers; + execution_providers.push_back(DefaultRocmExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (enable_cpu) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 4f7be01519..62bfd1cc9e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -16,11 +16,8 @@ training_ops_path = 'orttraining/orttraining/training_ops' contrib_ops_excluded_files = [ 'bert/attention.cc', - 'bert/attention.h', 'bert/attention_impl.cu', - 'bert/attention_impl.h', - 'bert/attention_transpose.cu', - 'bert/attention_concat.cu', + 'bert/attention_softmax.h', 'bert/decoder_attention.h', 'bert/decoder_attention.cc', 'bert/embed_layer_norm.cc',