Zhanyao/attention (#10545)

* Enable Attention op for ROCM EP.

As a note, potential hipify improvements: (1) handle math
contants (attention_softmax.h), (2) correctly generate transpose
options for the GEMM helpers, consider counterpart/dummy API for
CublasMathModeSetter (attention_impl.cu, attention_impl.cu). After
these improvements, we don't need to manually keep copies of the
above mentioned files any more.

* Clean up debugging code.
This commit is contained in:
zhangyaobit 2022-02-17 09:02:45 -08:00 committed by GitHub
parent 8d06e5a9df
commit fd16085cea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 773 additions and 7 deletions

View file

@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/rocm/bert/attention.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
using namespace onnxruntime::rocm;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Attention<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info) {}
template <typename T>
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const Tensor* extra_add_qk = context->Input<Tensor>(5);
auto& device_prop = GetDeviceProp();
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock));
// input shape (batch_size, sequence_length, input_hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int input_hidden_size = static_cast<int>(shape[2]);
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
int head_size = hidden_size / num_heads_;
TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
Tensor* output = context->Output(0, output_shape);
int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
rocblas_handle rocblas = RocblasHandle();
constexpr size_t element_size = sizeof(T);
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
typedef typename ToHipType<T>::MappedType HipT;
HipT one = ToHipType<T>::FromFloat(1.0f);
HipT zero = ToHipType<T>::FromFloat(0.0f);
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
// TODO: use custom kernel of expand to improve the performance.
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
reinterpret_cast<const HipT*>(bias->template Data<T>()), n,
GetConstOnes<HipT>(m), 1,
&zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n));
// Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
reinterpret_cast<const HipT*>(weights->template Data<T>()), n,
reinterpret_cast<const HipT*>(input->template Data<T>()), k,
&one, reinterpret_cast<HipT*>(gemm_buffer.get()), n));
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length);
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
device_prop,
Stream(),
reinterpret_cast<const HipT*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
output->template MutableData<T>(),
batch_size,
sequence_length,
num_heads_,
head_size,
temp_buffer.get(),
rocblas,
element_size,
is_unidirectional_,
past_sequence_length,
nullptr == past ? nullptr : past->template Data<T>(),
nullptr == extra_add_qk ? nullptr : extra_add_qk->template Data<T>(),
nullptr == present ? nullptr : present->template MutableData<T>())) {
// Get last error to reset it to hipSuccess.
HIP_CALL(hipGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,213 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Modifications: scaling is moved from masked softmax to the gemm before that.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "contrib_ops/rocm/bert/attention_impl.h"
#include "contrib_ops/rocm/bert/attention_softmax.h"
#include "contrib_ops/rocm/bert/transformer_common.h"
using namespace onnxruntime::rocm;
using namespace hipcub;
#define CHECK_ROCM(expr) \
if (!HIP_CALL(expr)) { \
return false; \
}
namespace onnxruntime {
namespace contrib {
namespace rocm {
static size_t AlignTo(size_t a, size_t b) {
return CeilDiv(a, b) * b;
}
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
const size_t len = batch_size * num_heads * sequence_length * all_sequence_length;
const size_t bytes = len * element_size;
const size_t alignment = 256;
const size_t bytesAligned = AlignTo(bytes, alignment);
return bytesAligned;
}
size_t GetAttentionWorkspaceSize(
size_t element_size,
int batch_size,
int num_heads,
int head_size,
int sequence_length,
int past_sequence_length) {
size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}
template <typename T>
bool QkvToContext(
const hipDeviceProp_t& prop, rocblas_handle& rocblas, hipStream_t stream,
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
const T* input, T* output, T* workspace,
const int* mask_index, gsl::span<const int64_t> mask_index_dims,
bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) {
const int all_sequence_length = past_sequence_length + sequence_length;
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
T* scratch1 = workspace;
T* scratch2 = scratch1 + (bytes / element_size);
T* scratch3 = scratch2 + (bytes / element_size);
const int max_threads_per_block = prop.maxThreadsPerBlock;
// input should be BxSx3xNxH => scratch3: 3xBxNxSxH
if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) {
return false;
}
// now scratch3 has Q, K, V: each has size BxNxSxH
const int batches = batch_size * num_heads;
const int size_per_batch = sequence_length * head_size;
const int total_size = batches * size_per_batch;
const T* q = scratch3;
const T* k = q + total_size;
const T* v = k + total_size;
rocblas_set_stream(rocblas, stream);
// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
const int present_size_per_batch = all_sequence_length * head_size;
if (nullptr != present) {
if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) {
return false;
}
// update pointers to present_k and present_v.
k = present;
v = present + batches * present_size_per_batch;
}
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const int temp_matrix_size = sequence_length * all_sequence_length;
typedef typename ToHipType<T>::MappedType HipT;
//float one = 1.0f;
//float zero = 0.f;
const HipT one = ToHipType<T>::FromFloat(1.0f);
const HipT zero = ToHipType<T>::FromFloat(0.f);
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
//float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size;
const HipT alpha = use_raw_attention_mask ? one : ToHipType<T>::FromFloat(rsqrt_head_size);
if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
rocblas, rocblas_operation_transpose, rocblas_operation_none, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch,
q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) {
return false;
}
// apply softmax and store result P to scratch2: BxNxSxS*
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims.size());
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0;
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2,
is_unidirectional, rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length),
use_persistent_softmax, persistent_softmax_workspace)) {
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
ORT_ENFORCE(mask_index_dims.size() == 1);
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr;
if (!ComputeSoftmaxWithMask1D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
return false;
}
} else { // no mask
if (!ComputeSoftmax<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
return false;
}
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
rocblas, rocblas_operation_none, rocblas_operation_none, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch,
scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches))) {
return false;
}
// scratch3 is BxNxSxH, transpose to output BxSxNxH
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output);
}
bool LaunchAttentionKernel(
const hipDeviceProp_t& prop,
hipStream_t stream,
const void* input,
const int* mask_index,
gsl::span<const int64_t> mask_index_dims,
void* output,
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
void* workspace,
rocblas_handle& rocblas,
const size_t element_size,
bool is_unidirectional,
int past_sequence_length,
const void* past,
const void* extra_add_qk,
void* present) {
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
const TransformerOptions* options = TransformerOptions::GetInstance();
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const __half*>(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const __half*>(past), reinterpret_cast<const __half*>(extra_add_qk),
reinterpret_cast<__half*>(present), use_persistent_softmax);
} else {
return QkvToContext(prop, rocblas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output), reinterpret_cast<float*>(workspace),
mask_index, mask_index_dims, is_unidirectional,
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<const float*>(extra_add_qk),
reinterpret_cast<float*>(present), use_persistent_softmax);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,423 @@
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <type_traits>
#include <hipcub/hipcub.hpp>
#include <hip/hip_fp16.h>
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/math/softmax.h"
#define ROCMRT_INF_F __int_as_float(0x7f800000)
using namespace onnxruntime::rocm;
using namespace hipcub;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T, unsigned TPB>
__device__ inline void Softmax(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* add_before_softmax,
const T* input,
T* output) {
using BlockReduce = hipcub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
float thread_data_max(-ROCMRT_INF_F);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
if (thread_data_max < input_at_idx) {
thread_data_max = input_at_idx;
}
}
}
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max());
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_sum(0.f);
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index];
thread_data_sum += expf(val - max_block);
}
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum());
if (threadIdx.x == 0) {
sum_reverse_block = 1.f / sum;
}
__syncthreads();
for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
const int index = offset + i;
float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f;
output[index] = T(val);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxSmall(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* add_before_softmax,
const T* input,
T* output,
bool is_unidirectional) {
using BlockReduce = hipcub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
const int index = offset + threadIdx.x;
bool is_valid = false; // whether it has attention mask == 1.
// Update end position for unidirectional.
int end = valid_end;
if (is_unidirectional) {
int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
if (end_unid <= valid_start) {
// In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000.
// So [0, end_unid) will also have value after softmax.
is_valid = threadIdx.x < end_unid;
} else {
end = min(valid_end, end_unid);
}
}
is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
float input_data = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp(0.f);
if (is_valid) {
thread_data_exp = expf(input_data - max_block);
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end);
// Store value of 1.0/sum.
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
// threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
const int sequence_length,
const int* attention_mask, // 2D, 3D or 4D attention mask
const bool* key_padding_mask,
const T* add_before_softmax,
const T* input,
T* output,
const bool is_unidirectional,
const float rsqrt_head_size,
const int mask_dimension,
const int max_sequence_length,
const bool skip_softmax) {
using BlockReduce = hipcub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
float thread_data = -ROCMRT_INF_F;
if (threadIdx.x < all_sequence_length) {
if (add_before_softmax == nullptr) {
thread_data = float(input[index]) * rsqrt_head_size;
} else {
thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size;
}
const int sequence_index = blockIdx.x % sequence_length;
if (is_unidirectional) {
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
if (threadIdx.x > from_index) {
thread_data = -10000.0f;
}
}
int mask_offset = 0;
const int batch_index = blockIdx.y;
if (mask_dimension == 2) {
mask_offset = batch_index * all_sequence_length + threadIdx.x;
} else if (mask_dimension == 3) {
mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
} else if (mask_dimension == 4) {
mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x;
}
if (nullptr == key_padding_mask) {
const int& mask = attention_mask[mask_offset];
if (mask == 0)
thread_data += -10000.0f;
} else {
const bool mask = key_padding_mask[mask_offset];
if (mask) {
thread_data = -ROCMRT_INF_F;
}
}
}
if (skip_softmax) {
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data);
}
return;
}
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length);
// Store value of 1.0/sum
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output) {
Softmax<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output);
}
template <typename T>
bool ComputeSoftmax(
hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output);
} else {
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
}
return HIP_CALL(hipPeekAtLastError());
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
Softmax<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input,
T* output, const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length,
const bool skip_softmax) {
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax);
}
template <typename T>
bool ComputeSoftmaxWithMask1D(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* mask_index, const int* mask_start, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output);
} else {
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
}
return HIP_CALL(hipPeekAtLastError());
}
template <typename T>
bool ComputeSoftmaxWithRawMask(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional,
const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, const bool use_persistent_softmax, T* persistent_softmax_workspace) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
T* out = use_persistent_softmax ? persistent_softmax_workspace : output;
if (all_sequence_length <= 32) {
const int blockSize = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
} else {
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
}
if (use_persistent_softmax) {
dispatch_warpwise_softmax_forward<T, T, float, false>(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length);
}
return HIP_CALL(hipPeekAtLastError());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -125,8 +125,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop)>,

View file

@ -46,10 +46,11 @@ static void RunAttentionTest(
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu;
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
int head_size = hidden_size / number_of_heads;
if (enable_cpu || enable_cuda) {
if (enable_cpu || enable_cuda || enable_rocm) {
OpTester tester("Attention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
@ -151,6 +152,12 @@ static void RunAttentionTest(
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (enable_rocm) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultRocmExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
if (enable_cpu) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());

View file

@ -16,11 +16,8 @@ training_ops_path = 'orttraining/orttraining/training_ops'
contrib_ops_excluded_files = [
'bert/attention.cc',
'bert/attention.h',
'bert/attention_impl.cu',
'bert/attention_impl.h',
'bert/attention_transpose.cu',
'bert/attention_concat.cu',
'bert/attention_softmax.h',
'bert/decoder_attention.h',
'bert/decoder_attention.cc',
'bert/embed_layer_norm.cc',