mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Refactor rocm attention (#14688)
Extract QKV projection and attention computation into pipelines (composed from gemms and kernel launch). This will allow us to introduce ck flash attention in next PR
This commit is contained in:
parent
f3b6664384
commit
a997bb46b6
7 changed files with 628 additions and 470 deletions
|
|
@ -1,144 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/bert/attention.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
constexpr int kPastSequenceLengthInputIndex = 6;
|
||||
constexpr int kPastInputIndex = 4;
|
||||
constexpr int kPresentOutputIndex = 1;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Attention, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
|
||||
Attention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
template <typename T>
|
||||
Attention<T>::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info, true) {}
|
||||
|
||||
template <typename T>
|
||||
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
const Tensor* relative_position_bias = context->Input<Tensor>(5);
|
||||
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
AttentionParameters parameters;
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights->Shape(),
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past,
|
||||
relative_position_bias,
|
||||
¶meters,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
past_seq_len));
|
||||
ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
std::vector<int64_t> present_dims{
|
||||
2, parameters.batch_size, parameters.num_heads,
|
||||
parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length,
|
||||
parameters.head_size};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present = context->Output(kPresentOutputIndex, present_shape);
|
||||
|
||||
rocblas_handle rocblas = GetRocblasHandle(context);
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
int m = parameters.batch_size * parameters.sequence_length;
|
||||
int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size);
|
||||
int k = parameters.input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
|
||||
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
namespace blas = rocm::tunable::blas;
|
||||
|
||||
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
|
||||
// TODO: use custom kernel of expand to improve the performance.
|
||||
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
|
||||
GetTuningContext(), Stream(context), rocblas,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
n, m, 1,
|
||||
/*alpha=*/1.0f,
|
||||
reinterpret_cast<const HipT*>(bias->Data<T>()), n,
|
||||
GetConstOnes<HipT>(m, Stream(context)), 1,
|
||||
/*beta=*/0.0f,
|
||||
reinterpret_cast<HipT*>(gemm_buffer.get()), n));
|
||||
|
||||
// result(N, M) = 1 * weights x input + 1 x B.
|
||||
ORT_RETURN_IF_ERROR(blas::column_major::Gemm(
|
||||
GetTuningContext(), Stream(context), rocblas,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
n, m, k,
|
||||
/*alpha=*/1.0f,
|
||||
reinterpret_cast<const HipT*>(weights->Data<T>()), n,
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()), k,
|
||||
/*beta=*/1.0f,
|
||||
reinterpret_cast<HipT*>(gemm_buffer.get()), n));
|
||||
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
parameters.batch_size,
|
||||
parameters.num_heads,
|
||||
parameters.head_size,
|
||||
parameters.sequence_length,
|
||||
parameters.past_sequence_length);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
return LaunchAttentionKernel(
|
||||
device_prop,
|
||||
GetTuningContext(),
|
||||
Stream(context),
|
||||
rocblas,
|
||||
element_size,
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.head_size,
|
||||
parameters.past_sequence_length,
|
||||
parameters.is_unidirectional,
|
||||
reinterpret_cast<const void*>(gemm_buffer.get()),
|
||||
nullptr == mask_index ? nullptr : mask_index->Data<int>(),
|
||||
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
|
||||
parameters.mask_filter_value,
|
||||
nullptr == past ? nullptr : past->Data<T>(),
|
||||
nullptr == relative_position_bias ? nullptr : relative_position_bias->Data<T>(),
|
||||
work_space.get(),
|
||||
output->MutableData<T>(),
|
||||
nullptr == present ? nullptr : present->MutableData<T>());
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
173
onnxruntime/contrib_ops/rocm/bert/attention.cu
Normal file
173
onnxruntime/contrib_ops/rocm/bert/attention.cu
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/bert/attention.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "contrib_ops/rocm/bert/batched_gemm_permute_pipelines.cuh"
|
||||
#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh"
|
||||
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
constexpr int kPastSequenceLengthInputIndex = 6;
|
||||
constexpr int kPastInputIndex = 4;
|
||||
constexpr int kPresentOutputIndex = 1;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Attention, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
|
||||
Attention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
template <typename T>
|
||||
Attention<T>::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info, true) {}
|
||||
|
||||
template <typename T>
|
||||
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
const Tensor* relative_position_bias = context->Input<Tensor>(5);
|
||||
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
AttentionParameters attn;
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights->Shape(),
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past,
|
||||
relative_position_bias,
|
||||
&attn,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
past_seq_len));
|
||||
ORT_ENFORCE(attn.sequence_length == attn.kv_sequence_length); // self attention
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = static_cast<int64_t>(attn.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(attn.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(attn.v_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
std::vector<int64_t> present_dims{
|
||||
2, attn.batch_size, attn.num_heads,
|
||||
past_present_share_buffer_ ? attn.max_sequence_length : attn.total_sequence_length,
|
||||
attn.head_size};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present = context->Output(kPresentOutputIndex, present_shape);
|
||||
|
||||
auto stream = Stream(context);
|
||||
rocblas_handle rocblas = GetRocblasHandle(context);
|
||||
|
||||
using HipT = typename ToHipType<T>::MappedType;
|
||||
using QkvProjectGeneric = GemmPermuteGenericPipeline<HipT>;
|
||||
using AttentionGeneric = GemmSoftmaxGemmPermuteGenericPipeline<HipT>;
|
||||
|
||||
size_t qkv_project_output_bytes = QkvProjectGeneric::GetOutputNumBytes(&attn);
|
||||
size_t attention_workspace_bytes = AttentionGeneric::GetWorkspaceNumBytes(&attn);
|
||||
ORT_ENFORCE(QkvProjectGeneric::GetWorkspaceNumBytes(&attn) <= attention_workspace_bytes); // workspace reuse
|
||||
|
||||
auto qkv_project_output = GetScratchBuffer<void>(qkv_project_output_bytes, context->GetComputeStream());
|
||||
auto workspace = GetScratchBuffer<void>(attention_workspace_bytes, context->GetComputeStream());
|
||||
|
||||
GemmPermuteParams<HipT> gemm_permute_params;
|
||||
{
|
||||
auto& params = gemm_permute_params;
|
||||
params.tuning_ctx = GetTuningContext();
|
||||
params.stream = stream;
|
||||
params.handle = rocblas;
|
||||
params.attention = &attn;
|
||||
params.device_prop = &device_prop;
|
||||
|
||||
params.input_buffer = reinterpret_cast<const HipT*>(input->DataRaw());
|
||||
params.weight_buffer = reinterpret_cast<const HipT*>(weights->DataRaw());
|
||||
params.bias_buffer = reinterpret_cast<const HipT*>(bias->DataRaw());
|
||||
params.out_buffer = reinterpret_cast<HipT*>(qkv_project_output.get());
|
||||
params.ones = GetConstOnes<HipT>(attn.batch_size * attn.sequence_length, stream);
|
||||
params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get()); // workspace reuse
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(QkvProjectGeneric::Run(&gemm_permute_params));
|
||||
auto [q_buffer, k_buffer, v_buffer] = QkvProjectGeneric::UnspliceOutputQKV(&gemm_permute_params);
|
||||
|
||||
if (nullptr != present) {
|
||||
// Concat past (2xBxNxS'xH) to present (2xBxNxTxH):
|
||||
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxTxH)
|
||||
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxTxH)
|
||||
const int batches = attn.batch_size * attn.num_heads;
|
||||
const int present_size_per_batch = attn.total_sequence_length * attn.head_size;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(Stream(context),
|
||||
attn.total_sequence_length,
|
||||
attn.sequence_length,
|
||||
attn.batch_size,
|
||||
attn.head_size,
|
||||
attn.num_heads,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
nullptr == past ? nullptr : reinterpret_cast<const HipT*>(past->DataRaw()),
|
||||
k_buffer,
|
||||
reinterpret_cast<HipT*>(present->MutableDataRaw())));
|
||||
|
||||
// update pointers to present_k and present_v.
|
||||
k_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw());
|
||||
v_buffer = reinterpret_cast<HipT*>(present->MutableDataRaw()) + batches * present_size_per_batch;
|
||||
}
|
||||
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
|
||||
|
||||
GemmSoftmaxGemmPermuteParams<HipT> gemm_softmax_gemm_permute_params;
|
||||
{
|
||||
auto& params = gemm_softmax_gemm_permute_params;
|
||||
params.tuning_ctx = GetTuningContext();
|
||||
params.stream = Stream(context);
|
||||
params.handle = rocblas;
|
||||
params.attention = &attn;
|
||||
params.device_prop = &device_prop;
|
||||
// FIXME: the params.scale seems to be different from AttentionParameters::scale;
|
||||
params.scale = 1.0f / sqrt(static_cast<float>(attn.head_size));
|
||||
params.q_buffer = q_buffer;
|
||||
params.k_buffer = k_buffer;
|
||||
params.v_buffer = v_buffer;
|
||||
params.out_buffer = reinterpret_cast<HipT*>(output->MutableDataRaw());
|
||||
|
||||
if (relative_position_bias != nullptr) {
|
||||
params.bias_buffer = reinterpret_cast<const HipT*>(relative_position_bias->DataRaw());
|
||||
}
|
||||
|
||||
if (mask_index != nullptr) {
|
||||
params.mask_index_buffer = mask_index->Data<int>();
|
||||
params.mask_index_dims = mask_index->Shape().GetDims();
|
||||
}
|
||||
|
||||
params.workspace_buffer = reinterpret_cast<HipT*>(workspace.get());
|
||||
}
|
||||
|
||||
return AttentionGeneric::Run(&gemm_softmax_gemm_permute_params, use_persistent_softmax);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -26,17 +26,21 @@ limitations under the License.
|
|||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "contrib_ops/rocm/bert/attention_softmax.h"
|
||||
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace hipcub;
|
||||
|
||||
namespace blas = onnxruntime::rocm::tunable::blas;
|
||||
|
||||
#define CHECK_ROCM(expr) HIP_RETURN_IF_ERROR(expr)
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
|
@ -49,8 +53,8 @@ size_t GetAttentionScratchSize(size_t element_size,
|
|||
int batch_size,
|
||||
int num_heads,
|
||||
int sequence_length,
|
||||
int all_sequence_length) {
|
||||
const size_t bytes = element_size * batch_size * num_heads * sequence_length * all_sequence_length;
|
||||
int total_sequence_length) {
|
||||
const size_t bytes = element_size * batch_size * num_heads * sequence_length * total_sequence_length;
|
||||
|
||||
const size_t alignment = 256;
|
||||
const size_t bytesAligned = AlignTo(bytes, alignment);
|
||||
|
|
@ -69,181 +73,9 @@ size_t GetAttentionWorkspaceSize(
|
|||
sequence_length, past_sequence_length + sequence_length);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status QkvToContext(
|
||||
const hipDeviceProp_t& prop,
|
||||
RocmTuningContext* tuning_ctx,
|
||||
rocblas_handle& rocblas,
|
||||
hipStream_t stream,
|
||||
const int batch_size,
|
||||
const int sequence_length,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const size_t element_size,
|
||||
const T* input,
|
||||
T* output,
|
||||
T* workspace,
|
||||
const int* mask_index,
|
||||
gsl::span<const int64_t> mask_index_dims,
|
||||
const float mask_filter_value,
|
||||
bool is_unidirectional,
|
||||
int past_sequence_length,
|
||||
const T* past,
|
||||
const T* relative_position_bias,
|
||||
T* present,
|
||||
bool use_persistent_softmax) {
|
||||
const int all_sequence_length = past_sequence_length + sequence_length;
|
||||
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
|
||||
sequence_length, all_sequence_length);
|
||||
T* scratch1 = workspace;
|
||||
T* scratch2 = scratch1 + (bytes / element_size);
|
||||
T* scratch3 = scratch2 + (bytes / element_size);
|
||||
|
||||
const int max_threads_per_block = prop.maxThreadsPerBlock;
|
||||
|
||||
// input should be BxSx3xNxH => scratch3: 3xBxNxSxH
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, false, input, scratch3));
|
||||
|
||||
// now scratch3 has Q, K, V: each has size BxNxSxH
|
||||
const int batches = batch_size * num_heads;
|
||||
const int size_per_batch = sequence_length * head_size;
|
||||
const int total_size = batches * size_per_batch;
|
||||
|
||||
const T* q = scratch3;
|
||||
const T* k = q + total_size;
|
||||
const T* v = k + total_size;
|
||||
|
||||
rocblas_set_stream(rocblas, stream);
|
||||
|
||||
// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
|
||||
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
|
||||
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
|
||||
const int present_size_per_batch = all_sequence_length * head_size;
|
||||
if (nullptr != present) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, past, k, present));
|
||||
|
||||
// update pointers to present_k and present_v.
|
||||
k = present;
|
||||
v = present + batches * present_size_per_batch;
|
||||
}
|
||||
|
||||
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
|
||||
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
|
||||
|
||||
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
|
||||
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
|
||||
const int temp_matrix_size = sequence_length * all_sequence_length;
|
||||
|
||||
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
|
||||
tuning_ctx, stream, rocblas,
|
||||
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
|
||||
all_sequence_length, sequence_length, head_size,
|
||||
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
|
||||
/*alpha=*/use_raw_attention_mask ? 1.0f : rsqrt_head_size,
|
||||
k, head_size, present_size_per_batch,
|
||||
q, head_size, size_per_batch,
|
||||
/*beta=*/0.0f,
|
||||
scratch1, all_sequence_length, temp_matrix_size,
|
||||
batches));
|
||||
|
||||
// apply softmax and store result P to scratch2: BxNxSxS*
|
||||
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
|
||||
const int mask_dimension = static_cast<int>(mask_index_dims.size());
|
||||
const int max_sequence_length = mask_dimension == 4 ? static_cast<int>(mask_index_dims[3]) : 0;
|
||||
|
||||
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place if persistent softmax is selected.
|
||||
ORT_RETURN_IF_ERROR(
|
||||
ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads,
|
||||
mask_index, nullptr, relative_position_bias, scratch1, scratch2,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, persistent_softmax_workspace, mask_filter_value));
|
||||
} else if (nullptr != mask_index) { // 1d mask index
|
||||
ORT_ENFORCE(mask_index_dims.size() == 1);
|
||||
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
|
||||
const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr;
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads,
|
||||
mask_index, mask_start, relative_position_bias, scratch1, scratch2, is_unidirectional));
|
||||
} else { // no mask
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads,
|
||||
relative_position_bias, scratch1, scratch2, is_unidirectional));
|
||||
}
|
||||
|
||||
// compute P*V (as V*P), and store in scratch3: BxNxSxH
|
||||
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
|
||||
tuning_ctx, stream, rocblas,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
head_size, sequence_length, all_sequence_length,
|
||||
/*alpha=*/1.0f,
|
||||
v, head_size, present_size_per_batch,
|
||||
scratch2, all_sequence_length, temp_matrix_size,
|
||||
/*beta=*/0.0f,
|
||||
scratch3, head_size, size_per_batch,
|
||||
batches));
|
||||
|
||||
// scratch3 is BxNxSxH, transpose to output BxSxNxH
|
||||
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
|
||||
max_threads_per_block, false, scratch3, output);
|
||||
}
|
||||
|
||||
Status LaunchAttentionKernel(
|
||||
const hipDeviceProp_t& prop,
|
||||
RocmTuningContext* tuning_ctx,
|
||||
hipStream_t stream,
|
||||
rocblas_handle& rocblas,
|
||||
const size_t element_size,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int head_size,
|
||||
int past_sequence_length,
|
||||
bool is_unidirectional,
|
||||
const void* input,
|
||||
const int* mask_index,
|
||||
gsl::span<const int64_t> mask_index_dims,
|
||||
const float mask_filter_value,
|
||||
const void* past,
|
||||
const void* relative_position_bias,
|
||||
void* workspace,
|
||||
void* output,
|
||||
void* present) {
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
|
||||
if (element_size == 2) {
|
||||
return QkvToContext(
|
||||
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const __half*>(input),
|
||||
reinterpret_cast<__half*>(output),
|
||||
reinterpret_cast<__half*>(workspace),
|
||||
mask_index,
|
||||
mask_index_dims,
|
||||
mask_filter_value,
|
||||
is_unidirectional,
|
||||
past_sequence_length,
|
||||
reinterpret_cast<const __half*>(past),
|
||||
reinterpret_cast<const __half*>(relative_position_bias),
|
||||
reinterpret_cast<__half*>(present),
|
||||
use_persistent_softmax);
|
||||
} else {
|
||||
return QkvToContext(
|
||||
prop, tuning_ctx, rocblas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const float*>(input),
|
||||
reinterpret_cast<float*>(output),
|
||||
reinterpret_cast<float*>(workspace),
|
||||
mask_index,
|
||||
mask_index_dims,
|
||||
mask_filter_value,
|
||||
is_unidirectional,
|
||||
past_sequence_length,
|
||||
reinterpret_cast<const float*>(past),
|
||||
reinterpret_cast<const float*>(relative_position_bias),
|
||||
reinterpret_cast<float*>(present),
|
||||
use_persistent_softmax);
|
||||
}
|
||||
inline int3 Get2DMaskStrides(int total_sequence_length) {
|
||||
// stride == 0 indicate broadcasting
|
||||
return {total_sequence_length, 0, 1};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -375,9 +207,11 @@ Status DecoderQkvToContext(
|
|||
}
|
||||
|
||||
if (has_key_padding_mask) {
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(stream, kv_sequence_length, sequence_length, batch_size,
|
||||
num_heads, nullptr, key_padding_mask, nullptr, scratch1, scratch2,
|
||||
false, 1, 2, static_cast<int>(0), false, nullptr, mask_filter_value));
|
||||
int3 strides = Get2DMaskStrides(kv_sequence_length);
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(
|
||||
stream, kv_sequence_length, sequence_length, batch_size, num_heads,
|
||||
strides, nullptr, key_padding_mask, nullptr, scratch1, scratch2,
|
||||
false, 1.0f, false, nullptr, mask_filter_value));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, kv_sequence_length, sequence_length, batch_size,
|
||||
num_heads, nullptr, scratch1, scratch2, false));
|
||||
|
|
|
|||
|
|
@ -26,29 +26,6 @@ size_t GetAttentionWorkspaceSize(
|
|||
int sequence_length,
|
||||
int past_sequence_length);
|
||||
|
||||
Status LaunchAttentionKernel(
|
||||
const hipDeviceProp_t& prop, // Device Properties
|
||||
RocmTuningContext* tuning_ctx, // context for tuning
|
||||
hipStream_t stream, // Hip stream
|
||||
rocblas_handle& rocblas, // Rocblas handle
|
||||
const size_t element_size, // Element size of input tensor
|
||||
int batch_size, // Batch size (B)
|
||||
int sequence_length, // Sequence length (S)
|
||||
int num_heads, // Number of attention heads (N)
|
||||
int head_size, // Hidden layer size per head (H)
|
||||
int past_sequence_length, // Sequence length in past state
|
||||
bool is_unidirectional, // Whether there is unidirectional mask.
|
||||
const void* input, // Input tensor
|
||||
const int* mask_index, // Attention mask raw data or index. NULL means no mask.
|
||||
gsl::span<const int64_t> mask_index_dims, // Mask index shape
|
||||
const float mask_filter_value, // Mask value for filtered out positions
|
||||
const void* past, // Past state input
|
||||
const void* relative_position_bias, // Additional Add
|
||||
void* workspace, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* present // Present state output
|
||||
);
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const hipDeviceProp_t& prop, // Device Properties
|
||||
RocmTuningContext* tuning_ctx, // context for tuning
|
||||
|
|
|
|||
|
|
@ -174,20 +174,23 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
|
|||
}
|
||||
}
|
||||
|
||||
// Note about the attention_mask_strides and attention_mask/key_padding_mask
|
||||
// attention_mask accepts 2D, 3D or 4D tensor, but it will be viewed as 3D tensor uniformally and it will be indexed
|
||||
// as [batch_index, sequence_index, token_index].
|
||||
template <typename T, unsigned TPB>
|
||||
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int* attention_mask, // 2D, 3D or 4D attention mask
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const int mask_dimension,
|
||||
const int max_sequence_length,
|
||||
const bool skip_softmax,
|
||||
const float mask_filter_value) {
|
||||
__global__ void SoftmaxWithRawMaskSmallKernel(
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int3 attention_mask_strides,
|
||||
const int* attention_mask, // 2D, 3D or 4D attention mask
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const bool skip_softmax,
|
||||
const float mask_filter_value) {
|
||||
using BlockReduce = hipcub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp_storage;
|
||||
|
||||
|
|
@ -216,16 +219,10 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
}
|
||||
}
|
||||
|
||||
int mask_offset = 0;
|
||||
const int batch_index = blockIdx.y;
|
||||
if (mask_dimension == 2) {
|
||||
mask_offset = batch_index * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 3) {
|
||||
mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 4) {
|
||||
int from_index = all_sequence_length - sequence_length + sequence_index;
|
||||
mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x;
|
||||
}
|
||||
int mask_offset = attention_mask_strides.x * batch_index +
|
||||
attention_mask_strides.y * sequence_index +
|
||||
attention_mask_strides.z * threadIdx.x;
|
||||
|
||||
if (nullptr == key_padding_mask) {
|
||||
const int& mask = attention_mask[mask_offset];
|
||||
|
|
@ -320,7 +317,7 @@ Status ComputeSoftmax(
|
|||
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, add_before_softmax, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
|
|
@ -375,26 +372,6 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int seq
|
|||
add_before_softmax, input, output);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int* attention_mask,
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input, T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const int mask_dimension,
|
||||
const int max_sequence_length,
|
||||
const bool skip_softmax,
|
||||
const float mask_filter_value) {
|
||||
SoftmaxWithRawMaskSmall<T, TPB>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, output,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
skip_softmax, mask_filter_value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeSoftmaxWithMask1D(
|
||||
hipStream_t stream,
|
||||
|
|
@ -403,115 +380,83 @@ Status ComputeSoftmaxWithMask1D(
|
|||
const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
|
||||
MaskedSoftmaxKernelSmall<T, block_size><<<grid, block_size, 0, stream>>>( \
|
||||
all_sequence_length, sequence_length, mask_index, mask_start, \
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
MaskedSoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output, is_unidirectional);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024);
|
||||
} else if (!is_unidirectional) {
|
||||
const int blockSize = 1024;
|
||||
MaskedSoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length, mask_index, mask_start,
|
||||
add_before_softmax, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE
|
||||
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeSoftmaxWithRawMask(hipStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int* attention_mask,
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const int mask_dimension,
|
||||
const int max_sequence_length,
|
||||
const bool use_persistent_softmax,
|
||||
T* persistent_softmax_workspace,
|
||||
const float mask_filter_value) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int3 attention_mask_strides,
|
||||
const int* attention_mask,
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const bool use_persistent_softmax,
|
||||
T* persistent_softmax_workspace,
|
||||
const float mask_filter_value) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
T* out = use_persistent_softmax ? persistent_softmax_workspace : output;
|
||||
|
||||
#define DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(block_size) \
|
||||
SoftmaxWithRawMaskSmallKernel<T, block_size><<<grid, block_size, 0, stream>>>( \
|
||||
all_sequence_length, sequence_length, attention_mask_strides, \
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out, \
|
||||
is_unidirectional, rsqrt_head_size, \
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(32);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(64);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(128);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(256);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(512);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
SoftmaxWithRawMaskSmallKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
|
||||
all_sequence_length, sequence_length,
|
||||
attention_mask, key_padding_mask, add_before_softmax, input, out,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length,
|
||||
use_persistent_softmax, mask_filter_value);
|
||||
DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE(1024);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
#undef DISPATCH_KERNEL_SMALL_WITH_BLOCKSIZE
|
||||
|
||||
if (use_persistent_softmax) {
|
||||
return dispatch_warpwise_softmax_forward<T, T, float, false>(stream,
|
||||
output,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
namespace blas = onnxruntime::rocm::tunable::blas;
|
||||
|
||||
namespace {
|
||||
std::tuple<int, int, int, int> GetQkvProjectGemmMNKBatch(const AttentionParameters* attention) {
|
||||
int m = attention->sequence_length;
|
||||
int n = (attention->hidden_size + attention->hidden_size + attention->v_hidden_size); // q + k + v
|
||||
int k = attention->input_hidden_size;
|
||||
int batch = attention->batch_size;
|
||||
return {m, n, k, batch};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
||||
std::string Signature() const override {
|
||||
auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(attention);
|
||||
return MakeString("M", m, "_N", n, "_K", k, "_B", batch);
|
||||
}
|
||||
|
||||
rocblas_handle handle;
|
||||
const AttentionParameters* attention;
|
||||
const hipDeviceProp_t* device_prop;
|
||||
|
||||
const T* input_buffer;
|
||||
const T* weight_buffer;
|
||||
const T* bias_buffer;
|
||||
T* out_buffer;
|
||||
|
||||
int3 bias_strides;
|
||||
|
||||
const T* ones; // used for broadcasting bias if the underlying algorithm does not support strides
|
||||
T* workspace_buffer;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GemmPermuteGenericPipeline {
|
||||
inline static size_t GetOutputNumBytes(const AttentionParameters* attn) {
|
||||
auto [m, n, _, batch] = GetQkvProjectGemmMNKBatch(attn);
|
||||
return sizeof(T) * m * n * batch;
|
||||
}
|
||||
|
||||
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
|
||||
return GetOutputNumBytes(attn);
|
||||
}
|
||||
|
||||
inline static std::tuple<int, int, int> GetGemmMNK(const GemmPermuteParams<T>* params) {
|
||||
auto [m, n, k, batch] = GetQkvProjectGemmMNKBatch(params->attention);
|
||||
return {batch * m, n, k};
|
||||
}
|
||||
|
||||
inline static std::tuple<const T*, const T*, const T*> UnspliceOutputQKV(const GemmPermuteParams<T>* params) {
|
||||
auto* attn = params->attention;
|
||||
int64_t batch = attn->batch_size * attn->num_heads;
|
||||
int64_t num_elems_per_batch = attn->sequence_length * attn->head_size;
|
||||
int64_t num_elems = batch * num_elems_per_batch;
|
||||
auto q = params->out_buffer + 0 * num_elems;
|
||||
auto k = params->out_buffer + 1 * num_elems;
|
||||
auto v = params->out_buffer + 2 * num_elems;
|
||||
return {q, k, v};
|
||||
}
|
||||
|
||||
inline static Status BroadcastBias(const GemmPermuteParams<T>* params) {
|
||||
auto [m, n, k] = GetGemmMNK(params);
|
||||
// Bias shape is (N), broadcast using B(M, N) = ones(M, 1) x bias(1, N).
|
||||
// TODO: use custom kernel of expand to improve the performance.
|
||||
return blas::row_major::Gemm(
|
||||
params->TuningContext(), params->Stream(), params->handle,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
m, n, 1,
|
||||
/*alpha=*/1.0f,
|
||||
params->ones, 1,
|
||||
params->bias_buffer, n,
|
||||
/*beta=*/0.0f,
|
||||
params->workspace_buffer, n);
|
||||
}
|
||||
|
||||
inline static Status Gemm(const GemmPermuteParams<T>* params) {
|
||||
auto [m, n, k] = GetGemmMNK(params);
|
||||
// result(M, N) = input x weights + bias.
|
||||
return blas::row_major::Gemm(
|
||||
params->TuningContext(), params->Stream(), params->handle,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
m, n, k,
|
||||
/*alpha=*/1.0f,
|
||||
params->input_buffer, k,
|
||||
params->weight_buffer, n,
|
||||
/*beta=*/1.0f,
|
||||
params->workspace_buffer, n);
|
||||
}
|
||||
|
||||
inline static Status Permute0213(const GemmPermuteParams<T>* params) {
|
||||
auto* attn = params->attention;
|
||||
// input should be BxSx3xNxH => gemm_buffer: 3xBxNxSxH
|
||||
return LaunchTransQkv(
|
||||
params->Stream(), 3, attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads,
|
||||
params->device_prop->maxThreadsPerBlock, false, params->workspace_buffer, params->out_buffer);
|
||||
}
|
||||
|
||||
static Status Run(const GemmPermuteParams<T>* params) {
|
||||
ORT_RETURN_IF_ERROR(BroadcastBias(params));
|
||||
ORT_RETURN_IF_ERROR(Gemm(params));
|
||||
ORT_RETURN_IF_ERROR(Permute0213(params));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,249 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
/* About Computing in these Pipelines
|
||||
|
||||
B: batch size of Attention Op. NOTE: To be disambiguated with batch size of GEMMs
|
||||
S: sequence length
|
||||
T: total sequence length
|
||||
N: num of heads
|
||||
H: head dimension
|
||||
|
||||
BN: B*N, which is the batch size of GEMMs. NOTE: To be disambiguated with batch size of Attention Op
|
||||
|
||||
In QKV projection (prior to this pipeline):
|
||||
/-> Q [B,S,N*H] ->Reshape-> [B,S,N,H] ->Permute0213-> [B,N,S,H]
|
||||
X --o--> K [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H]
|
||||
\-> V [B,T,N*H] ->Reshape-> [B,T,N,H] ->Permute0213-> [B,N,T,H]
|
||||
|
||||
pre_softmax_attn_scores = Q*K' = [B,N,S,H] * [BxNxTxH]' = [B,N,S,T] Batched GEMM1
|
||||
pre_softmax_attn_scores_masked = pre_softmax_attn_scores +? bias +? mask Add Bias, +? is optional
|
||||
attn_scores = softmax(pre_softmax_attn_scores_masked * scale) = [B,N,S,T] Scale then Softmax
|
||||
scaled_multi_head_attn = attn_scores * V = [B,N,S,T] * [B,N,T,H] = [B,N,S,H] Batched GEMM2
|
||||
|
||||
Op outputs scaled_multi_head_attn:
|
||||
[B,N,S,H] ->Permute0213-> [B,S,N,H] ->Reshape-> [B,S,N*H]
|
||||
|
||||
|
||||
For the computing of pre_softmax_attn_scores +? mask +? bias:
|
||||
|
||||
GemmSoftmaxGemmPermuteGenericPipeline handles it in specialized softmax. TODO: remove it!
|
||||
|
||||
*/
|
||||
|
||||
#include "core/providers/rocm/tunable/gemm.h"
|
||||
#include "core/providers/rocm/tunable/rocm_tunable.h"
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "contrib_ops/rocm/bert/attention_softmax.h"
|
||||
|
||||
namespace blas = onnxruntime::rocm::tunable::blas;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
inline int3 Get2DMaskStrides(int total_sequence_length) {
|
||||
// stride == 0 indicate broadcasting
|
||||
return {total_sequence_length, 0, 1};
|
||||
}
|
||||
|
||||
inline std::tuple<const int*, int3, int3> GetRawMaskBufferAddrSizesAndStrides(
|
||||
const int* buffer, const AttentionParameters* attn) {
|
||||
const int* offseted_buffer{buffer}; // how to view the mask buffer
|
||||
int3 sizes{-1, -1, -1}; // the logical shape of the view
|
||||
int3 strides{-1, -1, -1}; // the physical memory layout
|
||||
switch (attn->mask_type) {
|
||||
case MASK_NONE:
|
||||
case MASK_2D_DUMMY:
|
||||
break; // No mask
|
||||
case MASK_2D_KEY_PADDING:
|
||||
sizes = {attn->batch_size, 1, attn->total_sequence_length};
|
||||
strides = Get2DMaskStrides(attn->total_sequence_length);
|
||||
break;
|
||||
case MASK_3D_ATTENTION:
|
||||
sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length};
|
||||
strides = {attn->sequence_length * attn->total_sequence_length, attn->total_sequence_length, 1};
|
||||
break;
|
||||
case MASK_4D_MEGATRON:
|
||||
// offset to skip past sequence part, so that we can index it with [batch_index, sequence_index, token_index]
|
||||
offseted_buffer = buffer + attn->past_sequence_length * attn->max_sequence_length;
|
||||
sizes = {attn->batch_size, attn->sequence_length, attn->total_sequence_length};
|
||||
strides = {attn->max_sequence_length * attn->max_sequence_length, attn->max_sequence_length, 1};
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("unsupported mask type");
|
||||
}
|
||||
return {offseted_buffer, sizes, strides};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
|
||||
std::string Signature() const override {
|
||||
auto [m, n, k, o, batch] = GetGemmsMNKOBatch();
|
||||
return MakeString("M", m, "_N", n, "_K", k, "_O", o, "_B", batch);
|
||||
}
|
||||
|
||||
std::tuple<int, int, int, int, int> GetGemmsMNKOBatch() const {
|
||||
ORT_ENFORCE(attention != nullptr);
|
||||
auto m = attention->sequence_length;
|
||||
auto n = attention->total_sequence_length;
|
||||
auto k = attention->head_size;
|
||||
auto o = attention->head_size;
|
||||
auto batch = attention->batch_size * attention->num_heads;
|
||||
return {m, n, k, o, batch};
|
||||
}
|
||||
|
||||
rocblas_handle handle;
|
||||
const AttentionParameters* attention;
|
||||
const hipDeviceProp_t* device_prop;
|
||||
|
||||
float scale;
|
||||
const T* q_buffer;
|
||||
const T* k_buffer;
|
||||
const T* v_buffer;
|
||||
T* out_buffer;
|
||||
|
||||
// optional, bias [B,N,S,T]
|
||||
const T* bias_buffer{nullptr};
|
||||
|
||||
// optional, mask value
|
||||
const int* mask_index_buffer{nullptr};
|
||||
gsl::span<const int64_t> mask_index_dims{};
|
||||
|
||||
// optional, internal
|
||||
T* workspace_buffer{nullptr};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GemmSoftmaxGemmPermuteGenericPipeline {
|
||||
static bool UseRawAttentionMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
return params->mask_index_buffer != nullptr && params->mask_index_dims.size() >= 2;
|
||||
}
|
||||
|
||||
static std::tuple<T*, T*, T*> GetWorkspacePlan(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
auto bytes = GetAttentionScratchSize(
|
||||
sizeof(T),
|
||||
params->attention->batch_size,
|
||||
params->attention->num_heads,
|
||||
params->attention->sequence_length,
|
||||
params->attention->total_sequence_length);
|
||||
auto gemm1_out = params->workspace_buffer;
|
||||
auto softmax_out = gemm1_out + (bytes / sizeof(T));
|
||||
auto gemm2_out = softmax_out + (bytes / sizeof(T));
|
||||
return {gemm1_out, softmax_out, gemm2_out};
|
||||
}
|
||||
|
||||
inline static size_t GetWorkspaceNumBytes(const AttentionParameters* attn) {
|
||||
return GetAttentionWorkspaceSize(
|
||||
sizeof(T),
|
||||
attn->batch_size,
|
||||
attn->num_heads,
|
||||
attn->head_size,
|
||||
attn->sequence_length,
|
||||
attn->past_sequence_length);
|
||||
}
|
||||
|
||||
inline static Status Gemm1(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
// GEMM1 [m,k] * [n,k]' -> [m,n]
|
||||
return blas::row_major::StridedBatchedGemm(
|
||||
params->TuningContext(), params->Stream(), params->handle,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::Trans,
|
||||
m, n, k,
|
||||
// For raw attention mask, the scalar is moved to softmax computation.
|
||||
/*alpha=*/UseRawAttentionMask(params) ? 1.0f : params->scale,
|
||||
params->q_buffer, k, m * k,
|
||||
params->k_buffer, k, n * k,
|
||||
/*beta=*/0.0f,
|
||||
gemm1_out, n, m * n,
|
||||
batch);
|
||||
}
|
||||
|
||||
inline static Status SoftmaxRawMask(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
|
||||
// Softmax on [m,n] along the n dimension.
|
||||
// Raw attention mask could be 2D (B,S) or 3D (B,S,T) or 4D(B,1,M,M), where M is the max sequence length.
|
||||
auto attn = params->attention;
|
||||
auto [buffer, sizes, strides] = GetRawMaskBufferAddrSizesAndStrides(params->mask_index_buffer, attn);
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
T* persistent_softmax_workspace = gemm1_out; // replace Q*K' in place if persistent softmax is selected.
|
||||
return ComputeSoftmaxWithRawMask<T>(
|
||||
params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads,
|
||||
strides, buffer, nullptr, params->bias_buffer, gemm1_out, softmax_out,
|
||||
attn->is_unidirectional, /* FIXME: this must not be attn.scale! */ params->scale,
|
||||
use_persistent_softmax, persistent_softmax_workspace, attn->mask_filter_value);
|
||||
}
|
||||
|
||||
inline static Status Softmax1DIndexMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
auto mask_1d = params->mask_index_buffer;
|
||||
auto mask_1d_size = params->mask_index_dims[0];
|
||||
// Softmax on [m,n] along the n dimension.
|
||||
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
|
||||
auto attn = params->attention;
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
const int* mask_start = (mask_1d_size > attn->batch_size) ? mask_1d + attn->batch_size : nullptr;
|
||||
return ComputeSoftmaxWithMask1D<T>(
|
||||
params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads,
|
||||
mask_1d, mask_start, params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional);
|
||||
}
|
||||
|
||||
inline static Status SoftmaxNoMask(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
// Softmax on [m,n] along the n dimension.
|
||||
auto attn = params->attention;
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
return ComputeSoftmax<T>(
|
||||
params->Stream(), attn->total_sequence_length, attn->sequence_length, attn->batch_size, attn->num_heads,
|
||||
params->bias_buffer, gemm1_out, softmax_out, attn->is_unidirectional);
|
||||
}
|
||||
|
||||
inline static Status Gemm2(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
// GEMM2 [m,n] * [n,o] -> [m,o]
|
||||
// semantically, the output buffer contains B*N matrices of shape [S,H], compactly, thus B,N,S,H.
|
||||
return blas::row_major::StridedBatchedGemm(
|
||||
params->TuningContext(), params->Stream(), params->handle,
|
||||
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
|
||||
m, o, n,
|
||||
/*alpha=*/1.0f,
|
||||
softmax_out, n, m * n,
|
||||
params->v_buffer, o, n * o,
|
||||
/*beta=*/0.0f,
|
||||
gemm2_out, o, m * o,
|
||||
batch);
|
||||
}
|
||||
|
||||
inline static Status Permute0213(const GemmSoftmaxGemmPermuteParams<T>* params) {
|
||||
// Permute 0213
|
||||
// gemm2_out is B,N,S,H, transpose to out_buffer as B,S,N,H
|
||||
auto attn = params->attention;
|
||||
auto [gemm1_out, softmax_out, gemm2_out] = GetWorkspacePlan(params);
|
||||
return LaunchTransCtx(
|
||||
params->Stream(),
|
||||
attn->sequence_length, attn->batch_size, attn->head_size, attn->num_heads,
|
||||
params->device_prop->maxThreadsPerBlock, false, gemm2_out, params->out_buffer);
|
||||
}
|
||||
|
||||
static Status Run(const GemmSoftmaxGemmPermuteParams<T>* params, bool use_persistent_softmax) {
|
||||
ORT_RETURN_IF_ERROR(Gemm1(params));
|
||||
|
||||
if (UseRawAttentionMask(params)) {
|
||||
ORT_RETURN_IF_ERROR(SoftmaxRawMask(params, use_persistent_softmax));
|
||||
} else if (params->mask_index_dims.size() == 1) { // 1d index mask
|
||||
ORT_RETURN_IF_ERROR(Softmax1DIndexMask(params));
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(SoftmaxNoMask(params));
|
||||
}
|
||||
|
||||
ORT_RETURN_IF_ERROR(Gemm2(params));
|
||||
ORT_RETURN_IF_ERROR(Permute0213(params));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue