diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cc b/onnxruntime/contrib_ops/rocm/bert/attention.cc deleted file mode 100644 index 1210442580..0000000000 --- a/onnxruntime/contrib_ops/rocm/bert/attention.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/attention.h" -#include "contrib_ops/rocm/bert/attention_impl.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/shared_inc/fpgeneric.h" -#include "core/providers/rocm/tunable/gemm.h" - -using namespace onnxruntime::rocm; -using namespace ::onnxruntime::common; -using namespace ONNX_NAMESPACE; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr int kPastSequenceLengthInputIndex = 6; -constexpr int kPastInputIndex = 4; -constexpr int kPresentOutputIndex = 1; - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Attention, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ - Attention); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) - -template -Attention::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info, true) {} - -template -Status Attention::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* weights = context->Input(1); - const Tensor* bias = context->Input(2); - const Tensor* mask_index = context->Input(3); - const Tensor* past = context->Input(4); - const Tensor* relative_position_bias = context->Input(5); - const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); - - auto& device_prop = GetDeviceProp(); - AttentionParameters parameters; - ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), - weights->Shape(), - bias->Shape(), - mask_index, - past, - relative_position_bias, - ¶meters, - device_prop.maxThreadsPerBlock, - past_seq_len)); - ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention - - TensorShapeVector output_shape(3); - output_shape[0] = static_cast(parameters.batch_size); - output_shape[1] = static_cast(parameters.sequence_length); - output_shape[2] = static_cast(parameters.v_hidden_size); - Tensor* output = context->Output(0, output_shape); - - std::vector present_dims{ - 2, parameters.batch_size, parameters.num_heads, - parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length, - parameters.head_size}; - TensorShape present_shape(present_dims); - Tensor* present = context->Output(kPresentOutputIndex, present_shape); - - rocblas_handle rocblas = GetRocblasHandle(context); - constexpr size_t element_size = sizeof(T); - - int m = parameters.batch_size * parameters.sequence_length; - int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size); - int k = parameters.input_hidden_size; - auto gemm_buffer = GetScratchBuffer(static_cast(m) * n, context->GetComputeStream()); - - typedef typename ToHipType::MappedType HipT; - namespace blas = rocm::tunable::blas; - - // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B. - // TODO: use custom kernel of expand to improve the performance. - ORT_RETURN_IF_ERROR(blas::column_major::Gemm( - GetTuningContext(), Stream(context), rocblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - n, m, 1, - /*alpha=*/1.0f, - reinterpret_cast(bias->Data()), n, - GetConstOnes(m, Stream(context)), 1, - /*beta=*/0.0f, - reinterpret_cast(gemm_buffer.get()), n)); - - // result(N, M) = 1 * weights x input + 1 x B. - ORT_RETURN_IF_ERROR(blas::column_major::Gemm( - GetTuningContext(), Stream(context), rocblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - n, m, k, - /*alpha=*/1.0f, - reinterpret_cast(weights->Data()), n, - reinterpret_cast(input->Data()), k, - /*beta=*/1.0f, - reinterpret_cast(gemm_buffer.get()), n)); - - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.sequence_length, - parameters.past_sequence_length); - - auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); - return LaunchAttentionKernel( - device_prop, - GetTuningContext(), - Stream(context), - rocblas, - element_size, - parameters.batch_size, - parameters.sequence_length, - parameters.num_heads, - parameters.head_size, - parameters.past_sequence_length, - parameters.is_unidirectional, - reinterpret_cast(gemm_buffer.get()), - nullptr == mask_index ? nullptr : mask_index->Data(), - nullptr == mask_index ? gsl::span() : mask_index->Shape().GetDims(), - parameters.mask_filter_value, - nullptr == past ? nullptr : past->Data(), - nullptr == relative_position_bias ? nullptr : relative_position_bias->Data(), - work_space.get(), - output->MutableData(), - nullptr == present ? nullptr : present->MutableData()); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention.cu b/onnxruntime/contrib_ops/rocm/bert/attention.cu new file mode 100644 index 0000000000..da4013a5e9 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/attention.cu @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/rocm/bert/attention.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" +#include "contrib_ops/rocm/bert/transformer_common.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/fpgeneric.h" +#include "core/providers/rocm/tunable/gemm.h" + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +constexpr int kPastSequenceLengthInputIndex = 6; +constexpr int kPastInputIndex = 4; +constexpr int kPresentOutputIndex = 1; + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Attention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ + Attention); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +Attention::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info, true) {} + +template +Status Attention::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* weights = context->Input(1); + const Tensor* bias = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* past = context->Input(4); + const Tensor* relative_position_bias = context->Input(5); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + + auto& device_prop = GetDeviceProp(); + AttentionParameters attn; + ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), + weights->Shape(), + bias->Shape(), + mask_index, + past, + relative_position_bias, + &attn, + device_prop.maxThreadsPerBlock, + past_seq_len)); + ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(attn.batch_size); + output_shape[1] = static_cast(attn.sequence_length); + output_shape[2] = static_cast(attn.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + std::vector present_dims{ + 2, attn.batch_size, attn.num_heads, + past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length, + attn.head_size}; + TensorShape present_shape(present_dims); + Tensor* present = context->Output(kPresentOutputIndex, present_shape); + + auto stream = Stream(context); + rocblas_handle rocblas = GetRocblasHandle(context); + + using HipT = typename ToHipType::MappedType; + using QkvProjectGeneric = GemmPermuteGenericPipeline; + using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline; + + size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn); + size_t attention_workspace_bytes = AttentionGeneric::GetWorkspaceNumBytes(&attn); + ORT_ENFORCE(QkvProjectGeneric::GetWorkspaceNumBytes(&attn) <= attention_workspace_bytes); // workspace reuse + + auto qkv_project_output = GetScratchBuffer(qkv_project_output_bytes, context->GetComputeStream()); + auto workspace = GetScratchBuffer(attention_workspace_bytes, context->GetComputeStream()); + + GemmPermuteParams gemm_permute_params; + { + auto& params = gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = stream; + params.handle = rocblas; + params.attention = &attn; + params.device_prop = &device_prop; + + params.input_buffer = reinterpret_cast(input->DataRaw()); + params.weight_buffer = reinterpret_cast(weights->DataRaw()); + params.bias_buffer = reinterpret_cast(bias->DataRaw()); + params.out_buffer = reinterpret_cast(qkv_project_output.get()); + params.ones = GetConstOnes(attn.batch_size * attn.sequence_length, stream); + params.workspace_buffer = reinterpret_cast(workspace.get()); // workspace reuse + } + + ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params)); + auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params); + + if (nullptr != present) { + // Concat past (2xBxNxS'xH) to present (2xBxNxTxH): + // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH) + // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH) + const int batches = attn.batch_size * attn.num_heads; + const int present_size_per_batch = attn.total_sequence_length * attn.head_size; + ORT_RETURN_IF_ERROR( + LaunchConcatPastToPresent(Stream(context), + attn.total_sequence_length, + attn.sequence_length, + attn.batch_size, + attn.head_size, + attn.num_heads, + device_prop.maxThreadsPerBlock, + nullptr == past ? nullptr : reinterpret_cast(past->DataRaw()), + k_buffer, + reinterpret_cast(present->MutableDataRaw()))); + + // update pointers to present_k and present_v. + k_buffer = reinterpret_cast(present->MutableDataRaw()); + v_buffer = reinterpret_cast(present->MutableDataRaw()) + batches * present_size_per_batch; + } + + // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax + const TransformerOptions* options = TransformerOptions::GetInstance(); + bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); + + GemmSoftmaxGemmPermuteParams gemm_softmax_gemm_permute_params; + { + auto& params = gemm_softmax_gemm_permute_params; + params.tuning_ctx = GetTuningContext(); + params.stream = Stream(context); + params.handle = rocblas; + params.attention = &attn; + params.device_prop = &device_prop; + // FIXME: the params.scale seems to be different from AttentionParameters::scale; + params.scale = 1.0f / sqrt(static_cast(attn.head_size)); + params.q_buffer = q_buffer; + params.k_buffer = k_buffer; + params.v_buffer = v_buffer; + params.out_buffer = reinterpret_cast(output->MutableDataRaw()); + + if (relative_position_bias != nullptr) { + params.bias_buffer = reinterpret_cast(relative_position_bias->DataRaw()); + } + + if (mask_index != nullptr) { + params.mask_index_buffer = mask_index->Data(); + params.mask_index_dims = mask_index->Shape().GetDims(); + } + + params.workspace_buffer = reinterpret_cast(workspace.get()); + } + + return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index e42fb2b2eb..8750334303 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -26,17 +26,21 @@ limitations under the License. #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/fpgeneric.h" #include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" -#include "contrib_ops/rocm/bert/transformer_common.h" using namespace onnxruntime::rocm; -using namespace hipcub; namespace blas = onnxruntime::rocm::tunable::blas; #define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr) +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + namespace onnxruntime { namespace contrib { namespace rocm { @@ -49,8 +53,8 @@ size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, - int all_sequence_length) { - const size_t bytes = element_size * batch_size * num_heads * sequence_length * all_sequence_length; + int total_sequence_length) { + const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length; const size_t alignment = 256; const size_t bytesAligned = AlignTo(bytes, alignment); @@ -69,181 +73,9 @@ size_t GetAttentionWorkspaceSize( sequence_length, past_sequence_length + sequence_length); } -template -Status QkvToContext( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - rocblas_handle& rocblas, - hipStream_t stream, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - const size_t element_size, - const T* input, - T* output, - T* workspace, - const int* mask_index, - gsl::span mask_index_dims, - const float mask_filter_value, - bool is_unidirectional, - int past_sequence_length, - const T* past, - const T* relative_position_bias, - T* present, - bool use_persistent_softmax) { - const int all_sequence_length = past_sequence_length + sequence_length; - const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, - sequence_length, all_sequence_length); - T* scratch1 = workspace; - T* scratch2 = scratch1 + (bytes / element_size); - T* scratch3 = scratch2 + (bytes / element_size); - - const int max_threads_per_block = prop.maxThreadsPerBlock; - - // input should be BxSx3xNxH => scratch3: 3xBxNxSxH - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, false, input, scratch3)); - - // now scratch3 has Q, K, V: each has size BxNxSxH - const int batches = batch_size * num_heads; - const int size_per_batch = sequence_length * head_size; - const int total_size = batches * size_per_batch; - - const T* q = scratch3; - const T* k = q + total_size; - const T* v = k + total_size; - - rocblas_set_stream(rocblas, stream); - - // Concat past (2xBxNxS'xH) to present (2xBxNxS*xH): - // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH) - // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) - const int present_size_per_batch = all_sequence_length * head_size; - if (nullptr != present) { - ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, past, k, present)); - - // update pointers to present_k and present_v. - k = present; - v = present + batches * present_size_per_batch; - } - - // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length. - bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* - // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * all_sequence_length; - - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, stream, rocblas, - blas::BlasOp::Trans, blas::BlasOp::NonTrans, - all_sequence_length, sequence_length, head_size, - // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation. - /*alpha=*/use_raw_attention_mask ? 1.0f : rsqrt_head_size, - k, head_size, present_size_per_batch, - q, head_size, size_per_batch, - /*beta=*/0.0f, - scratch1, all_sequence_length, temp_matrix_size, - batches)); - - // apply softmax and store result P to scratch2: BxNxSxS* - if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask - const int mask_dimension = static_cast(mask_index_dims.size()); - const int max_sequence_length = mask_dimension == 4 ? static_cast(mask_index_dims[3]) : 0; - - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected. - ORT_RETURN_IF_ERROR( - ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, relative_position_bias, scratch1, scratch2, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, persistent_softmax_workspace, mask_filter_value)); - } else if (nullptr != mask_index) { // 1d mask index - ORT_ENFORCE(mask_index_dims.size() == 1); - // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. - const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, - mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional)); - } else { // no mask - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, - relative_position_bias, scratch1, scratch2, is_unidirectional)); - } - - // compute P*V (as V*P), and store in scratch3: BxNxSxH - ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm( - tuning_ctx, stream, rocblas, - blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, - head_size, sequence_length, all_sequence_length, - /*alpha=*/1.0f, - v, head_size, present_size_per_batch, - scratch2, all_sequence_length, temp_matrix_size, - /*beta=*/0.0f, - scratch3, head_size, size_per_batch, - batches)); - - // scratch3 is BxNxSxH, transpose to output BxSxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, false, scratch3, output); -} - -Status LaunchAttentionKernel( - const hipDeviceProp_t& prop, - RocmTuningContext* tuning_ctx, - hipStream_t stream, - rocblas_handle& rocblas, - const size_t element_size, - int batch_size, - int sequence_length, - int num_heads, - int head_size, - int past_sequence_length, - bool is_unidirectional, - const void* input, - const int* mask_index, - gsl::span mask_index_dims, - const float mask_filter_value, - const void* past, - const void* relative_position_bias, - void* workspace, - void* output, - void* present) { - // For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax - const TransformerOptions* options = TransformerOptions::GetInstance(); - bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - if (element_size == 2) { - return QkvToContext( - prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), - reinterpret_cast<__half*>(output), - reinterpret_cast<__half*>(workspace), - mask_index, - mask_index_dims, - mask_filter_value, - is_unidirectional, - past_sequence_length, - reinterpret_cast(past), - reinterpret_cast(relative_position_bias), - reinterpret_cast<__half*>(present), - use_persistent_softmax); - } else { - return QkvToContext( - prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size, - reinterpret_cast(input), - reinterpret_cast(output), - reinterpret_cast(workspace), - mask_index, - mask_index_dims, - mask_filter_value, - is_unidirectional, - past_sequence_length, - reinterpret_cast(past), - reinterpret_cast(relative_position_bias), - reinterpret_cast(present), - use_persistent_softmax); - } +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; } template @@ -375,9 +207,11 @@ Status DecoderQkvToContext( } if (has_key_padding_mask) { - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, nullptr, scratch1, scratch2, - false, 1, 2, static_cast(0), false, nullptr, mask_filter_value)); + int3 strides = Get2DMaskStrides(kv_sequence_length); + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2, + false, 1.0f, false, nullptr, mask_filter_value)); } else { ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads, nullptr, scratch1, scratch2, false)); diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3fcfeb5175..c8ba9ec875 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -26,29 +26,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - hipStream_t stream, // Hip stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - int past_sequence_length, // Sequence length in past state - bool is_unidirectional, // Whether there is unidirectional mask. - const void* input, // Input tensor - const int* mask_index, // Attention mask raw data or index. NULL means no mask. - gsl::span mask_index_dims, // Mask index shape - const float mask_filter_value, // Mask value for filtered out positions - const void* past, // Past state input - const void* relative_position_bias, // Additional Add - void* workspace, // Temporary buffer - void* output, // Output tensor - void* present // Present state output -); - Status LaunchDecoderAttentionKernel( const hipDeviceProp_t& prop, // Device Properties RocmTuningContext* tuning_ctx, // context for tuning diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index 7c99fc05ec..8e47a62403 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -174,20 +174,23 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length, } } +// Note about the attention_mask_strides and attention_mask/key_padding_mask +// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed +// as [batch_index, sequence_index, token_index]. template -__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, - const int sequence_length, - const int* attention_mask, // 2D, 3D or 4D attention mask - const bool* key_padding_mask, - const T* add_before_softmax, - const T* input, - T* output, - const bool is_unidirectional, - const float rsqrt_head_size, - const int mask_dimension, - const int max_sequence_length, - const bool skip_softmax, - const float mask_filter_value) { +__global__ void SoftmaxWithRawMaskSmallKernel( + const int all_sequence_length, + const int sequence_length, + const int3 attention_mask_strides, + const int* attention_mask, // 2D, 3D or 4D attention mask + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const bool skip_softmax, + const float mask_filter_value) { using BlockReduce = hipcub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -216,16 +219,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } } - int mask_offset = 0; const int batch_index = blockIdx.y; - if (mask_dimension == 2) { - mask_offset = batch_index * all_sequence_length + threadIdx.x; - } else if (mask_dimension == 3) { - mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x; - } else if (mask_dimension == 4) { - int from_index = all_sequence_length - sequence_length + sequence_index; - mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x; - } + int mask_offset = attention_mask_strides.x * batch_index + + attention_mask_strides.y * sequence_index + + attention_mask_strides.z * threadIdx.x; if (nullptr == key_padding_mask) { const int& mask = attention_mask[mask_offset]; @@ -320,7 +317,7 @@ Status ComputeSoftmax( SoftmaxKernel<<>>( all_sequence_length, sequence_length, add_before_softmax, input, output); } else { - ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); } return HIP_CALL(hipPeekAtLastError()); @@ -375,26 +372,6 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq add_before_softmax, input, output); } -template -__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, - const int sequence_length, - const int* attention_mask, - const bool* key_padding_mask, - const T* add_before_softmax, - const T* input, T* output, - const bool is_unidirectional, - const float rsqrt_head_size, - const int mask_dimension, - const int max_sequence_length, - const bool skip_softmax, - const float mask_filter_value) { - SoftmaxWithRawMaskSmall( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, output, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - skip_softmax, mask_filter_value); -} - template Status ComputeSoftmaxWithMask1D( hipStream_t stream, @@ -403,115 +380,83 @@ Status ComputeSoftmaxWithMask1D( const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + MaskedSoftmaxKernelSmall<<>>( \ + all_sequence_length, sequence_length, mask_index, mask_start, \ + add_before_softmax, input, output, is_unidirectional); + if (all_sequence_length <= 32) { - const int blockSize = 32; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); } else if (all_sequence_length <= 64) { - const int blockSize = 64; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); } else if (all_sequence_length <= 128) { - const int blockSize = 128; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); } else if (all_sequence_length <= 256) { - const int blockSize = 256; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); } else if (all_sequence_length <= 512) { - const int blockSize = 512; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - MaskedSoftmaxKernelSmall<<>>( - all_sequence_length, sequence_length, mask_index, mask_start, - add_before_softmax, input, output, is_unidirectional); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); } else if (!is_unidirectional) { const int blockSize = 1024; MaskedSoftmaxKernel<<>>( all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output); } else { - ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); } +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + return HIP_CALL(hipPeekAtLastError()); } template Status ComputeSoftmaxWithRawMask(hipStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int num_heads, - const int* attention_mask, - const bool* key_padding_mask, - const T* add_before_softmax, - const T* input, - T* output, - const bool is_unidirectional, - const float rsqrt_head_size, - const int mask_dimension, - const int max_sequence_length, - const bool use_persistent_softmax, - T* persistent_softmax_workspace, - const float mask_filter_value) { + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int num_heads, + const int3 attention_mask_strides, + const int* attention_mask, + const bool* key_padding_mask, + const T* add_before_softmax, + const T* input, + T* output, + const bool is_unidirectional, + const float rsqrt_head_size, + const bool use_persistent_softmax, + T* persistent_softmax_workspace, + const float mask_filter_value) { const dim3 grid(sequence_length * num_heads, batch_size, 1); T* out = use_persistent_softmax ? persistent_softmax_workspace : output; + +#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \ + SoftmaxWithRawMaskSmallKernel<<>>( \ + all_sequence_length, sequence_length, attention_mask_strides, \ + attention_mask, key_padding_mask, add_before_softmax, input, out, \ + is_unidirectional, rsqrt_head_size, \ + use_persistent_softmax, mask_filter_value); + if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32); } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64); } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128); } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256); } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512); } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxWithRawMaskSmallKernel<<>>( - all_sequence_length, sequence_length, - attention_mask, key_padding_mask, add_before_softmax, input, out, - is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, - use_persistent_softmax, mask_filter_value); + DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024); } else { - ORT_THROW("Attention ROCM operator does not support total sequence length > 1024."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024."); } +#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE + if (use_persistent_softmax) { return dispatch_warpwise_softmax_forward(stream, output, diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000..486d6dca28 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/rocm_kernel.h" +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace { +std::tuple GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) { + int m = attention->sequence_length; + int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v + int k = attention->input_hidden_size; + int batch = attention->batch_size; + return {m, n, k, batch}; +} +} // namespace + +template +struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention); + return MakeString("M", m, "_N", n, "_K", k, "_B", batch); + } + + rocblas_handle handle; + const AttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + const T* input_buffer; + const T* weight_buffer; + const T* bias_buffer; + T* out_buffer; + + int3 bias_strides; + + const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides + T* workspace_buffer; +}; + +template +struct GemmPermuteGenericPipeline { + inline static size_t GetOutputNumBytes(const AttentionParameters* attn) { + auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn); + return sizeof(T) * m * n * batch; + } + + inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { + return GetOutputNumBytes(attn); + } + + inline static std::tuple GetGemmMNK(const GemmPermuteParams* params) { + auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention); + return {batch * m, n, k}; + } + + inline static std::tuple UnspliceOutputQKV(const GemmPermuteParams* params) { + auto* attn = params->attention; + int64_t batch = attn->batch_size * attn->num_heads; + int64_t num_elems_per_batch = attn->sequence_length * attn->head_size; + int64_t num_elems = batch * num_elems_per_batch; + auto q = params->out_buffer + 0 * num_elems; + auto k = params->out_buffer + 1 * num_elems; + auto v = params->out_buffer + 2 * num_elems; + return {q, k, v}; + } + + inline static Status BroadcastBias(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N). + // TODO: use custom kernel of expand to improve the performance. + return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, 1, + /*alpha=*/1.0f, + params->ones, 1, + params->bias_buffer, n, + /*beta=*/0.0f, + params->workspace_buffer, n); + } + + inline static Status Gemm(const GemmPermuteParams* params) { + auto [m, n, k] = GetGemmMNK(params); + // result(M, N) = input x weights + bias. + return blas::row_major::Gemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, n, k, + /*alpha=*/1.0f, + params->input_buffer, k, + params->weight_buffer, n, + /*beta=*/1.0f, + params->workspace_buffer, n); + } + + inline static Status Permute0213(const GemmPermuteParams* params) { + auto* attn = params->attention; + // input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH + return LaunchTransQkv( + params->Stream(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer); + } + + static Status Run(const GemmPermuteParams* params) { + ORT_RETURN_IF_ERROR(BroadcastBias(params)); + ORT_RETURN_IF_ERROR(Gemm(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh new file mode 100644 index 0000000000..ece4cd8912 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -0,0 +1,249 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +/* About Computing in these Pipelines + +B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs +S: sequence length +T: total sequence length +N: num of heads +H: head dimension + +BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op + +In QKV projection (prior to this pipeline): + /-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H] +X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + \-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H] + +pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1 +pre_softmax_attn_scores_masked = pre_softmax_attn_scores +? bias +? mask Add Bias, +? is optional +attn_scores = softmax(pre_softmax_attn_scores_masked * scale) = [B,N,S,T] Scale then Softmax +scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2 + +Op outputs scaled_multi_head_attn: +[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H] + + +For the computing of pre_softmax_attn_scores +? mask +? bias: + +GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it! + +*/ + +#include "core/providers/rocm/tunable/gemm.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/cpu/bert/attention_base.h" +#include "contrib_ops/rocm/bert/attention_impl.h" +#include "contrib_ops/rocm/bert/attention_softmax.h" + +namespace blas = onnxruntime::rocm::tunable::blas; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +inline int3 Get2DMaskStrides(int total_sequence_length) { + // stride == 0 indicate broadcasting + return {total_sequence_length, 0, 1}; +} + +inline std::tuple GetRawMaskBufferAddrSizesAndStrides( + const int* buffer, const AttentionParameters* attn) { + const int* offseted_buffer{buffer}; // how to view the mask buffer + int3 sizes{-1, -1, -1}; // the logical shape of the view + int3 strides{-1, -1, -1}; // the physical memory layout + switch (attn->mask_type) { + case MASK_NONE: + case MASK_2D_DUMMY: + break; // No mask + case MASK_2D_KEY_PADDING: + sizes = {attn->batch_size, 1, attn->total_sequence_length}; + strides = Get2DMaskStrides(attn->total_sequence_length); + break; + case MASK_3D_ATTENTION: + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1}; + break; + case MASK_4D_MEGATRON: + // offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index] + offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length; + sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length}; + strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1}; + break; + default: + throw std::runtime_error("unsupported mask type"); + } + return {offseted_buffer, sizes, strides}; +} + +template +struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams { + std::string Signature() const override { + auto [m, n, k, o, batch] = GetGemmsMNKOBatch(); + return MakeString("M", m, "_N", n, "_K", k, "_O", o, "_B", batch); + } + + std::tuple GetGemmsMNKOBatch() const { + ORT_ENFORCE(attention != nullptr); + auto m = attention->sequence_length; + auto n = attention->total_sequence_length; + auto k = attention->head_size; + auto o = attention->head_size; + auto batch = attention->batch_size * attention->num_heads; + return {m, n, k, o, batch}; + } + + rocblas_handle handle; + const AttentionParameters* attention; + const hipDeviceProp_t* device_prop; + + float scale; + const T* q_buffer; + const T* k_buffer; + const T* v_buffer; + T* out_buffer; + + // optional, bias [B,N,S,T] + const T* bias_buffer{nullptr}; + + // optional, mask value + const int* mask_index_buffer{nullptr}; + gsl::span mask_index_dims{}; + + // optional, internal + T* workspace_buffer{nullptr}; +}; + +template +struct GemmSoftmaxGemmPermuteGenericPipeline { + static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams* params) { + return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2; + } + + static std::tuple GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams* params) { + auto bytes = GetAttentionScratchSize( + sizeof(T), + params->attention->batch_size, + params->attention->num_heads, + params->attention->sequence_length, + params->attention->total_sequence_length); + auto gemm1_out = params->workspace_buffer; + auto softmax_out = gemm1_out + (bytes / sizeof(T)); + auto gemm2_out = softmax_out + (bytes / sizeof(T)); + return {gemm1_out, softmax_out, gemm2_out}; + } + + inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) { + return GetAttentionWorkspaceSize( + sizeof(T), + attn->batch_size, + attn->num_heads, + attn->head_size, + attn->sequence_length, + attn->past_sequence_length); + } + + inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + // GEMM1 [m,k] * [n,k]' -> [m,n] + return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::Trans, + m, n, k, + // For raw attention mask, the scalar is moved to softmax computation. + /*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale, + params->q_buffer, k, m * k, + params->k_buffer, k, n * k, + /*beta=*/0.0f, + gemm1_out, n, m * n, + batch); + } + + inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + // Softmax on [m,n] along the n dimension. + // Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length. + auto attn = params->attention; + auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected. + return ComputeSoftmaxWithRawMask( + params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out, + attn->is_unidirectional, /* FIXME: this must not be attn.scale! */ params->scale, + use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value); + } + + inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams* params) { + auto mask_1d = params->mask_index_buffer; + auto mask_1d_size = params->mask_index_dims[0]; + // Softmax on [m,n] along the n dimension. + // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr; + return ComputeSoftmaxWithMask1D( + params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams* params) { + // Softmax on [m,n] along the n dimension. + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return ComputeSoftmax( + params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads, + params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional); + } + + inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams* params) { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + // GEMM2 [m,n] * [n,o] -> [m,o] + // semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H. + return blas::row_major::StridedBatchedGemm( + params->TuningContext(), params->Stream(), params->handle, + blas::BlasOp::NonTrans, blas::BlasOp::NonTrans, + m, o, n, + /*alpha=*/1.0f, + softmax_out, n, m * n, + params->v_buffer, o, n * o, + /*beta=*/0.0f, + gemm2_out, o, m * o, + batch); + } + + inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams* params) { + // Permute 0213 + // gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H + auto attn = params->attention; + auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params); + return LaunchTransCtx( + params->Stream(), + attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads, + params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer); + } + + static Status Run(const GemmSoftmaxGemmPermuteParams* params, bool use_persistent_softmax) { + ORT_RETURN_IF_ERROR(Gemm1(params)); + + if (UseRawAttentionMask(params)) { + ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax)); + } else if (params->mask_index_dims.size() == 1) { // 1d index mask + ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params)); + } else { + ORT_RETURN_IF_ERROR(SoftmaxNoMask(params)); + } + + ORT_RETURN_IF_ERROR(Gemm2(params)); + ORT_RETURN_IF_ERROR(Permute0213(params)); + return Status::OK(); + } +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime