mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
Zhanyao/attention (#10545)
* Enable Attention op for ROCM EP. As a note, potential hipify improvements: (1) handle math constants (attention_softmax.h), (2) correctly generate transpose options for the GEMM helpers, and consider a counterpart/dummy API for CublasMathModeSetter (attention_impl.cu, attention_impl.cu). After these improvements, we would no longer need to manually keep copies of the above-mentioned files. * Clean up debugging code.
This commit is contained in:
parent
8d06e5a9df
commit
fd16085cea
6 changed files with 773 additions and 7 deletions
126
onnxruntime/contrib_ops/rocm/bert/attention.cc
Normal file
126
onnxruntime/contrib_ops/rocm/bert/attention.cc
Normal file
|
|
@ -0,0 +1,126 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/rocm/bert/attention.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Attention, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
Attention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
template <typename T>
|
||||
Attention<T>::Attention(const OpKernelInfo& info) : RocmKernel(info), AttentionBase(info) {}
|
||||
|
||||
template <typename T>
|
||||
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
const Tensor* extra_add_qk = context->Input<Tensor>(5);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), mask_index, past, extra_add_qk, device_prop.maxThreadsPerBlock));
|
||||
|
||||
// input shape (batch_size, sequence_length, input_hidden_size)
|
||||
const auto& shape = input->Shape();
|
||||
int batch_size = static_cast<int>(shape[0]);
|
||||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
// bias shape (3 * hidden_size)
|
||||
const auto& bias_shape = bias->Shape();
|
||||
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
|
||||
int head_size = hidden_size / num_heads_;
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
|
||||
|
||||
rocblas_handle rocblas = RocblasHandle();
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
HipT one = ToHipType<T>::FromFloat(1.0f);
|
||||
HipT zero = ToHipType<T>::FromFloat(0.0f);
|
||||
|
||||
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
|
||||
// TODO: use custom kernel of expand to improve the performance.
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
|
||||
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, 1, &one,
|
||||
reinterpret_cast<const HipT*>(bias->template Data<T>()), n,
|
||||
GetConstOnes<HipT>(m), 1,
|
||||
&zero, reinterpret_cast<HipT*>(gemm_buffer.get()), n));
|
||||
|
||||
// Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
|
||||
rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &one,
|
||||
reinterpret_cast<const HipT*>(weights->template Data<T>()), n,
|
||||
reinterpret_cast<const HipT*>(input->template Data<T>()), k,
|
||||
&one, reinterpret_cast<HipT*>(gemm_buffer.get()), n));
|
||||
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length, past_sequence_length);
|
||||
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
|
||||
if (!LaunchAttentionKernel(
|
||||
device_prop,
|
||||
Stream(),
|
||||
reinterpret_cast<const HipT*>(gemm_buffer.get()),
|
||||
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
|
||||
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
|
||||
output->template MutableData<T>(),
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_heads_,
|
||||
head_size,
|
||||
temp_buffer.get(),
|
||||
rocblas,
|
||||
element_size,
|
||||
is_unidirectional_,
|
||||
past_sequence_length,
|
||||
nullptr == past ? nullptr : past->template Data<T>(),
|
||||
nullptr == extra_add_qk ? nullptr : extra_add_qk->template Data<T>(),
|
||||
nullptr == present ? nullptr : present->template MutableData<T>())) {
|
||||
// Get last error to reset it to hipSuccess.
|
||||
HIP_CALL(hipGetLastError());
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
213
onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
Normal file
213
onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
/*
|
||||
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
|
||||
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
|
||||
|
||||
Copyright 2019 NVIDIA Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Modifications: scaling is moved from masked softmax to the gemm before that.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/fpgeneric.h"
|
||||
#include "contrib_ops/rocm/bert/attention_impl.h"
|
||||
#include "contrib_ops/rocm/bert/attention_softmax.h"
|
||||
#include "contrib_ops/rocm/bert/transformer_common.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace hipcub;
|
||||
|
||||
// Evaluates a HIP expression through HIP_CALL and makes the enclosing function
// return false when HIP_CALL reports failure.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// cannot capture a following `else`; use it with a trailing semicolon.
#define CHECK_ROCM(expr)   \
  do {                     \
    if (!HIP_CALL(expr)) { \
      return false;        \
    }                      \
  } while (0)
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
// Rounds `a` up to the nearest multiple of `b`. `b` must be non-zero.
static size_t AlignTo(size_t a, size_t b) {
  // Ceil-division followed by a multiply, spelled out in size_t arithmetic.
  return (a + b - 1) / b * b;
}

// Returns the size in bytes, rounded up to a multiple of 256, of one scratch
// buffer holding batch_size * num_heads * sequence_length * all_sequence_length
// elements of `element_size` bytes each.
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
  // Widen before multiplying: the element count can exceed the range of int
  // (e.g. 64 x 16 x 2048 x 2048), which would overflow if multiplied as int.
  const size_t len = static_cast<size_t>(batch_size) * num_heads * sequence_length * all_sequence_length;
  const size_t bytes = len * element_size;

  const size_t alignment = 256;
  const size_t bytesAligned = AlignTo(bytes, alignment);
  return bytesAligned;
}

// Returns the total workspace size in bytes for the attention kernel: room for
// the Q, K and V matrices (3 * batch_size * sequence_length * num_heads *
// head_size elements) plus two scratch buffers sized by GetAttentionScratchSize()
// for a total sequence length of past_sequence_length + sequence_length.
size_t GetAttentionWorkspaceSize(
    size_t element_size,
    int batch_size,
    int num_heads,
    int head_size,
    int sequence_length,
    int past_sequence_length) {
  // Same widening as above so the product is never evaluated in int.
  const size_t qkv_size = static_cast<size_t>(3) * batch_size * sequence_length * num_heads * head_size * element_size;
  return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}
|
||||
|
||||
template <typename T>
|
||||
bool QkvToContext(
|
||||
const hipDeviceProp_t& prop, rocblas_handle& rocblas, hipStream_t stream,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size,
|
||||
const T* input, T* output, T* workspace,
|
||||
const int* mask_index, gsl::span<const int64_t> mask_index_dims,
|
||||
bool is_unidirectional, int past_sequence_length, const T* past, const T* extra_add_qk, T* present, bool use_persistent_softmax) {
|
||||
const int all_sequence_length = past_sequence_length + sequence_length;
|
||||
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
|
||||
T* scratch1 = workspace;
|
||||
T* scratch2 = scratch1 + (bytes / element_size);
|
||||
T* scratch3 = scratch2 + (bytes / element_size);
|
||||
|
||||
const int max_threads_per_block = prop.maxThreadsPerBlock;
|
||||
|
||||
// input should be BxSx3xNxH => scratch3: 3xBxNxSxH
|
||||
if (!LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, input, scratch3)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// now scratch3 has Q, K, V: each has size BxNxSxH
|
||||
const int batches = batch_size * num_heads;
|
||||
const int size_per_batch = sequence_length * head_size;
|
||||
const int total_size = batches * size_per_batch;
|
||||
|
||||
const T* q = scratch3;
|
||||
const T* k = q + total_size;
|
||||
const T* v = k + total_size;
|
||||
|
||||
rocblas_set_stream(rocblas, stream);
|
||||
|
||||
// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
|
||||
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
|
||||
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
|
||||
const int present_size_per_batch = all_sequence_length * head_size;
|
||||
if (nullptr != present) {
|
||||
if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, past, k, present)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// update pointers to present_k and present_v.
|
||||
k = present;
|
||||
v = present + batches * present_size_per_batch;
|
||||
}
|
||||
|
||||
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
|
||||
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
|
||||
|
||||
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
|
||||
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
|
||||
const int temp_matrix_size = sequence_length * all_sequence_length;
|
||||
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
//float one = 1.0f;
|
||||
//float zero = 0.f;
|
||||
const HipT one = ToHipType<T>::FromFloat(1.0f);
|
||||
const HipT zero = ToHipType<T>::FromFloat(0.f);
|
||||
|
||||
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
|
||||
//float temp_alpha = use_raw_attention_mask ? one : rsqrt_head_size;
|
||||
const HipT alpha = use_raw_attention_mask ? one : ToHipType<T>::FromFloat(rsqrt_head_size);
|
||||
|
||||
if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
|
||||
rocblas, rocblas_operation_transpose, rocblas_operation_none, all_sequence_length, sequence_length, head_size, &alpha, k, head_size, present_size_per_batch,
|
||||
q, head_size, size_per_batch, &zero, scratch1, all_sequence_length, temp_matrix_size, batches))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// apply softmax and store result P to scratch2: BxNxSxS*
|
||||
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
|
||||
const int mask_dimension = static_cast<int>(mask_index_dims.size());
|
||||
const int64_t max_sequence_length = mask_dimension == 4 ? mask_index_dims.at(3) : 0;
|
||||
|
||||
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score if persistent softmax is selected.
|
||||
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, extra_add_qk, scratch1, scratch2,
|
||||
is_unidirectional, rsqrt_head_size, mask_dimension, static_cast<int>(max_sequence_length),
|
||||
use_persistent_softmax, persistent_softmax_workspace)) {
|
||||
return false;
|
||||
}
|
||||
} else if (nullptr != mask_index) { // 1d mask index
|
||||
ORT_ENFORCE(mask_index_dims.size() == 1);
|
||||
// mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions.
|
||||
const int* mask_start = (mask_index_dims.at(0) > batch_size) ? mask_index + batch_size : nullptr;
|
||||
if (!ComputeSoftmaxWithMask1D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
|
||||
return false;
|
||||
}
|
||||
} else { // no mask
|
||||
if (!ComputeSoftmax<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, extra_add_qk, scratch1, scratch2, is_unidirectional)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// compute P*V (as V*P), and store in scratch3: BxNxSxH
|
||||
if (!ROCBLAS_CALL(rocblasGemmStridedBatchedHelper(
|
||||
rocblas, rocblas_operation_none, rocblas_operation_none, head_size, sequence_length, all_sequence_length, &one, v, head_size, present_size_per_batch,
|
||||
scratch2, all_sequence_length, temp_matrix_size, &zero, scratch3, head_size, size_per_batch, batches))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// scratch3 is BxNxSxH, transpose to output BxSxNxH
|
||||
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, max_threads_per_block, false, scratch3, output);
|
||||
}
|
||||
|
||||
bool LaunchAttentionKernel(
|
||||
const hipDeviceProp_t& prop,
|
||||
hipStream_t stream,
|
||||
const void* input,
|
||||
const int* mask_index,
|
||||
gsl::span<const int64_t> mask_index_dims,
|
||||
void* output,
|
||||
const int batch_size,
|
||||
const int sequence_length,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
void* workspace,
|
||||
rocblas_handle& rocblas,
|
||||
const size_t element_size,
|
||||
bool is_unidirectional,
|
||||
int past_sequence_length,
|
||||
const void* past,
|
||||
const void* extra_add_qk,
|
||||
void* present) {
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
|
||||
if (element_size == 2) {
|
||||
return QkvToContext(prop, rocblas, stream,
|
||||
batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const __half*>(input), reinterpret_cast<__half*>(output), reinterpret_cast<__half*>(workspace),
|
||||
mask_index, mask_index_dims, is_unidirectional,
|
||||
past_sequence_length, reinterpret_cast<const __half*>(past), reinterpret_cast<const __half*>(extra_add_qk),
|
||||
reinterpret_cast<__half*>(present), use_persistent_softmax);
|
||||
} else {
|
||||
return QkvToContext(prop, rocblas, stream,
|
||||
batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
reinterpret_cast<const float*>(input), reinterpret_cast<float*>(output), reinterpret_cast<float*>(workspace),
|
||||
mask_index, mask_index_dims, is_unidirectional,
|
||||
past_sequence_length, reinterpret_cast<const float*>(past), reinterpret_cast<const float*>(extra_add_qk),
|
||||
reinterpret_cast<float*>(present), use_persistent_softmax);
|
||||
}
|
||||
}
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
423
onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
Normal file
423
onnxruntime/contrib_ops/rocm/bert/attention_softmax.h
Normal file
|
|
@ -0,0 +1,423 @@
|
|||
#include "hip/hip_runtime.h"
|
||||
/*
|
||||
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
|
||||
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
|
||||
|
||||
Copyright 2019 NVIDIA Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/math/softmax.h"
|
||||
|
||||
#define ROCMRT_INF_F __int_as_float(0x7f800000)
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace hipcub;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__device__ inline void Softmax(const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int valid_end,
|
||||
const int valid_start,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output) {
|
||||
using BlockReduce = hipcub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp_storage;
|
||||
|
||||
__shared__ float sum_reverse_block;
|
||||
__shared__ float max_block;
|
||||
|
||||
float thread_data_max(-ROCMRT_INF_F);
|
||||
|
||||
// e^x is represented as infinity if x is large enough, like 100.f.
|
||||
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
|
||||
// a math transform as below is leveraged to get a stable softmax:
|
||||
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
|
||||
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
|
||||
for (int i = threadIdx.x; i < valid_end; i += TPB) {
|
||||
if (i >= valid_start) {
|
||||
const int index = offset + i;
|
||||
float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
|
||||
if (thread_data_max < input_at_idx) {
|
||||
thread_data_max = input_at_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max());
|
||||
|
||||
// Store max value
|
||||
if (threadIdx.x == 0) {
|
||||
max_block = max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float thread_data_sum(0.f);
|
||||
for (int i = threadIdx.x; i < valid_end; i += TPB) {
|
||||
if (i >= valid_start) {
|
||||
const int index = offset + i;
|
||||
float val = add_before_softmax == nullptr ? input[index] : input[index] + add_before_softmax[index];
|
||||
thread_data_sum += expf(val - max_block);
|
||||
}
|
||||
}
|
||||
|
||||
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, hipcub::Sum());
|
||||
if (threadIdx.x == 0) {
|
||||
sum_reverse_block = 1.f / sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
|
||||
const int index = offset + i;
|
||||
float input_at_idx = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
|
||||
const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f;
|
||||
output[index] = T(val);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__device__ inline void SoftmaxSmall(const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int valid_end,
|
||||
const int valid_start,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
bool is_unidirectional) {
|
||||
using BlockReduce = hipcub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp_storage;
|
||||
|
||||
__shared__ float sum_reverse_block;
|
||||
__shared__ float max_block;
|
||||
|
||||
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
|
||||
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
|
||||
const int index = offset + threadIdx.x;
|
||||
|
||||
bool is_valid = false; // whether it has attention mask == 1.
|
||||
|
||||
// Update end position for unidirectional.
|
||||
int end = valid_end;
|
||||
if (is_unidirectional) {
|
||||
int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
|
||||
if (end_unid <= valid_start) {
|
||||
// In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000.
|
||||
// So [0, end_unid) will also have value after softmax.
|
||||
is_valid = threadIdx.x < end_unid;
|
||||
} else {
|
||||
end = min(valid_end, end_unid);
|
||||
}
|
||||
}
|
||||
|
||||
is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end);
|
||||
|
||||
// e^x is represented as infinity if x is large enough, like 100.f.
|
||||
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
|
||||
// a math transform as below is leveraged to get a stable softmax:
|
||||
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
|
||||
float input_data = add_before_softmax == nullptr ? float(input[index]) : float(input[index] + add_before_softmax[index]);
|
||||
float thread_data_max = is_valid ? input_data : float(-ROCMRT_INF_F);
|
||||
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, hipcub::Max(), end);
|
||||
|
||||
// Store max value
|
||||
if (threadIdx.x == 0) {
|
||||
max_block = max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float thread_data_exp(0.f);
|
||||
if (is_valid) {
|
||||
thread_data_exp = expf(input_data - max_block);
|
||||
}
|
||||
|
||||
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), end);
|
||||
|
||||
// Store value of 1.0/sum.
|
||||
if (threadIdx.x == 0) {
|
||||
sum_reverse_block = (1.f) / sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
output[index] = T(thread_data_exp * sum_reverse_block);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int* attention_mask, // 2D, 3D or 4D attention mask
|
||||
const bool* key_padding_mask,
|
||||
const T* add_before_softmax,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional,
|
||||
const float rsqrt_head_size,
|
||||
const int mask_dimension,
|
||||
const int max_sequence_length,
|
||||
const bool skip_softmax) {
|
||||
using BlockReduce = hipcub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp_storage;
|
||||
|
||||
__shared__ float sum_reverse_block;
|
||||
__shared__ float max_block;
|
||||
|
||||
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
|
||||
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
|
||||
|
||||
float thread_data = -ROCMRT_INF_F;
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
if (add_before_softmax == nullptr) {
|
||||
thread_data = float(input[index]) * rsqrt_head_size;
|
||||
} else {
|
||||
thread_data = float(input[index] + add_before_softmax[index]) * rsqrt_head_size;
|
||||
}
|
||||
|
||||
const int sequence_index = blockIdx.x % sequence_length;
|
||||
if (is_unidirectional) {
|
||||
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
|
||||
if (threadIdx.x > from_index) {
|
||||
thread_data = -10000.0f;
|
||||
}
|
||||
}
|
||||
|
||||
int mask_offset = 0;
|
||||
const int batch_index = blockIdx.y;
|
||||
if (mask_dimension == 2) {
|
||||
mask_offset = batch_index * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 3) {
|
||||
mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
|
||||
} else if (mask_dimension == 4) {
|
||||
mask_offset = (batch_index * max_sequence_length + all_sequence_length - sequence_length + sequence_index) * max_sequence_length + threadIdx.x;
|
||||
}
|
||||
|
||||
if (nullptr == key_padding_mask) {
|
||||
const int& mask = attention_mask[mask_offset];
|
||||
if (mask == 0)
|
||||
thread_data += -10000.0f;
|
||||
} else {
|
||||
const bool mask = key_padding_mask[mask_offset];
|
||||
if (mask) {
|
||||
thread_data = -ROCMRT_INF_F;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (skip_softmax) {
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
output[index] = T(thread_data);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length);
|
||||
|
||||
// Store max value
|
||||
if (threadIdx.x == 0) {
|
||||
max_block = max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
|
||||
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length);
|
||||
|
||||
// Store value of 1.0/sum
|
||||
if (threadIdx.x == 0) {
|
||||
sum_reverse_block = (1.f) / sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
output[index] = T(thread_data_exp * sum_reverse_block);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
|
||||
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output, is_unidirectional);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* add_before_softmax, const T* input, T* output) {
|
||||
Softmax<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, add_before_softmax, input, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ComputeSoftmax(
|
||||
hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (!is_unidirectional) {
|
||||
const int blockSize = 1024;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, add_before_softmax, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output, bool is_unidirectional) {
|
||||
__shared__ int start_position;
|
||||
__shared__ int end_position;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
const int batch = blockIdx.y;
|
||||
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
|
||||
end_position = min(all_sequence_length, mask_end[batch]);
|
||||
|
||||
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
|
||||
if (start_position >= end_position) {
|
||||
start_position = 0;
|
||||
end_position = all_sequence_length;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output, is_unidirectional);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* add_before_softmax, const T* input, T* output) {
|
||||
__shared__ int start_position;
|
||||
__shared__ int end_position;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
const int batch = blockIdx.y;
|
||||
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
|
||||
end_position = min(all_sequence_length, mask_end[batch]);
|
||||
|
||||
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
|
||||
if (start_position >= end_position) {
|
||||
start_position = 0;
|
||||
end_position = all_sequence_length;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
Softmax<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, add_before_softmax, input, output);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input,
|
||||
T* output, const bool is_unidirectional, const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length,
|
||||
const bool skip_softmax) {
|
||||
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, output, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, skip_softmax);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ComputeSoftmaxWithMask1D(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const int* mask_index, const int* mask_start, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernelSmall<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output, is_unidirectional);
|
||||
} else if (!is_unidirectional) {
|
||||
const int blockSize = 1024;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(MaskedSoftmaxKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, mask_index, mask_start, add_before_softmax, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ComputeSoftmaxWithRawMask(hipStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const int* attention_mask, const bool* key_padding_mask, const T* add_before_softmax, const T* input, T* output, const bool is_unidirectional,
|
||||
const float rsqrt_head_size, const int mask_dimension, const int max_sequence_length, const bool use_persistent_softmax, T* persistent_softmax_workspace) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
T* out = use_persistent_softmax ? persistent_softmax_workspace : output;
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SoftmaxWithRawMaskSmallKernel<T, blockSize>), grid, blockSize, 0, stream, all_sequence_length, sequence_length, attention_mask, key_padding_mask, add_before_softmax, input, out, is_unidirectional, rsqrt_head_size, mask_dimension, max_sequence_length, use_persistent_softmax);
|
||||
} else {
|
||||
ORT_THROW("Attention ROCM operator does not support total sequence length > 1024.");
|
||||
}
|
||||
|
||||
if (use_persistent_softmax) {
|
||||
dispatch_warpwise_softmax_forward<T, T, float, false>(stream, output, persistent_softmax_workspace, all_sequence_length, all_sequence_length, batch_size * num_heads * sequence_length);
|
||||
}
|
||||
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -125,8 +125,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Affine)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Affine)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, Crop)>,
|
||||
|
|
|
|||
|
|
@ -46,10 +46,11 @@ static void RunAttentionTest(
|
|||
|
||||
int min_cuda_architecture = use_float16 ? 530 : 0;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu;
|
||||
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
|
||||
|
||||
int head_size = hidden_size / number_of_heads;
|
||||
if (enable_cpu || enable_cuda) {
|
||||
if (enable_cpu || enable_cuda || enable_rocm) {
|
||||
OpTester tester("Attention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
tester.AddAttribute<int64_t>("unidirectional", static_cast<int64_t>(is_unidirectional ? 1 : 0));
|
||||
|
|
@ -151,6 +152,12 @@ static void RunAttentionTest(
|
|||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
if (enable_rocm) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
|
||||
if (enable_cpu) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
|
|
|
|||
|
|
@ -16,11 +16,8 @@ training_ops_path = 'orttraining/orttraining/training_ops'
|
|||
|
||||
contrib_ops_excluded_files = [
|
||||
'bert/attention.cc',
|
||||
'bert/attention.h',
|
||||
'bert/attention_impl.cu',
|
||||
'bert/attention_impl.h',
|
||||
'bert/attention_transpose.cu',
|
||||
'bert/attention_concat.cu',
|
||||
'bert/attention_softmax.h',
|
||||
'bert/decoder_attention.h',
|
||||
'bert/decoder_attention.cc',
|
||||
'bert/embed_layer_norm.cc',
|
||||
|
|
|
|||
Loading…
Reference in a new issue