mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[ROCm] Support for gpt2-based model inferencing (#14675)
When inferencing real gpt2-based model, found some gaps between CUDA and ROCm codebase. The fixes include: 1. minimum code change to fix tensor shape on Attention Op 2. Support optional output tensor with SkipLayerNorm 3. fix a build error found on MI200 --------- Co-authored-by: Ubuntu <ettao@ettao-amd-dev1.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
parent
a216c9a3fa
commit
d49cea05fa
9 changed files with 175 additions and 95 deletions
|
|
@ -15,6 +15,10 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
constexpr int kPastSequenceLengthInputIndex = 6;
|
||||
constexpr int kPastInputIndex = 4;
|
||||
constexpr int kPresentOutputIndex = 1;
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Attention, \
|
||||
|
|
@ -22,8 +26,10 @@ namespace rocm {
|
|||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
|
||||
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
|
||||
Attention<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
|
|
@ -40,47 +46,41 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const Tensor* past = context->Input<Tensor>(4);
|
||||
const Tensor* relative_position_bias = context->Input<Tensor>(5);
|
||||
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
|
||||
|
||||
auto& device_prop = GetDeviceProp();
|
||||
AttentionParameters parameters;
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
|
||||
weights->Shape(),
|
||||
bias->Shape(),
|
||||
mask_index,
|
||||
past,
|
||||
relative_position_bias,
|
||||
nullptr,
|
||||
device_prop.maxThreadsPerBlock));
|
||||
|
||||
// input shape (batch_size, sequence_length, input_hidden_size)
|
||||
const auto& shape = input->Shape();
|
||||
int batch_size = static_cast<int>(shape[0]);
|
||||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP
|
||||
// bias shape (3 * hidden_size)
|
||||
const auto& bias_shape = bias->Shape();
|
||||
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
|
||||
int head_size = hidden_size / num_heads_;
|
||||
¶meters,
|
||||
device_prop.maxThreadsPerBlock,
|
||||
past_seq_len));
|
||||
ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
|
||||
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
|
||||
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
|
||||
std::vector<int64_t> present_dims{
|
||||
2, parameters.batch_size, parameters.num_heads,
|
||||
parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length,
|
||||
parameters.head_size};
|
||||
TensorShape present_shape(present_dims);
|
||||
Tensor* present = context->Output(kPresentOutputIndex, present_shape);
|
||||
|
||||
rocblas_handle rocblas = GetRocblasHandle(context);
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
|
||||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size, context->GetComputeStream());
|
||||
int m = parameters.batch_size * parameters.sequence_length;
|
||||
int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size);
|
||||
int k = parameters.input_hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
|
||||
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
namespace blas = rocm::tunable::blas;
|
||||
|
|
@ -108,8 +108,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
/*beta=*/1.0f,
|
||||
reinterpret_cast<HipT*>(gemm_buffer.get()), n));
|
||||
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
|
||||
sequence_length, past_sequence_length);
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
parameters.batch_size,
|
||||
parameters.num_heads,
|
||||
parameters.head_size,
|
||||
parameters.sequence_length,
|
||||
parameters.past_sequence_length);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
return LaunchAttentionKernel(
|
||||
|
|
@ -118,16 +122,16 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
Stream(context),
|
||||
rocblas,
|
||||
element_size,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_heads_,
|
||||
head_size,
|
||||
past_sequence_length,
|
||||
is_unidirectional_,
|
||||
parameters.batch_size,
|
||||
parameters.sequence_length,
|
||||
parameters.num_heads,
|
||||
parameters.head_size,
|
||||
parameters.past_sequence_length,
|
||||
parameters.is_unidirectional,
|
||||
reinterpret_cast<const void*>(gemm_buffer.get()),
|
||||
nullptr == mask_index ? nullptr : mask_index->Data<int>(),
|
||||
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
|
||||
mask_filter_value_,
|
||||
parameters.mask_filter_value,
|
||||
nullptr == past ? nullptr : past->Data<T>(),
|
||||
nullptr == relative_position_bias ? nullptr : relative_position_bias->Data<T>(),
|
||||
work_space.get(),
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
Tensor* output = ctx->Output(0, input->Shape());
|
||||
|
||||
// For inferencing, we support one more optional output which is the sum
|
||||
// of the input and skip tensors
|
||||
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
|
||||
|
||||
if (input->Shape() != skip->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have same shape as input");
|
||||
|
|
@ -101,6 +105,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
GetTuningContext(),
|
||||
Stream(ctx),
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<HipT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(skip->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(gamma->Data<T>()),
|
||||
|
|
|
|||
|
|
@ -41,12 +41,12 @@ namespace rocm {
|
|||
|
||||
template <typename T>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, int ld, int element_count) {
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
|
||||
const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) {
|
||||
// this must be true because element_count is the total size of the tensor
|
||||
assert(element_count % ld == 0);
|
||||
|
||||
SkipLayerNormParams<T> params(tuning_ctx, stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
|
||||
SkipLayerNormParams<T> params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
|
||||
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
static SkipLayerNormTunableOp<T> op;
|
||||
|
|
@ -57,13 +57,13 @@ Status LaunchSkipLayerNormKernel(
|
|||
}
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<float>(
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, const float* input,
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, float* skip_input_bias_add_output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, int ld,
|
||||
int element_count);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<half>(
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, const half* input,
|
||||
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, half* skip_input_bias_add_output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, int ld,
|
||||
int element_count);
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ Status LaunchSkipLayerNormKernel(
|
|||
RocmTuningContext* tuning,
|
||||
hipStream_t stream,
|
||||
T* output, // output tensor
|
||||
T* skip_input_bias_add_output, // optional output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ half maybe2half(float x) {
|
|||
template <typename T, unsigned TPB>
|
||||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
|
||||
const T epsilon, T* output) {
|
||||
const T epsilon, T* output, T* skip_input_bias_add_output) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
|
|
@ -39,6 +39,11 @@ __global__ void SkipLayerNormKernel(
|
|||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
|
||||
const T rldval = reverse_ld * val;
|
||||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
|
||||
|
||||
if (skip_input_bias_add_output != nullptr) {
|
||||
skip_input_bias_add_output[idx] = val;
|
||||
}
|
||||
|
||||
output[idx] = val;
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +54,8 @@ __global__ void SkipLayerNormKernel(
|
|||
template <typename T, unsigned TPB, int ILP>
|
||||
__global__ void SkipLayerNormKernelVec(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, bool hasBias) {
|
||||
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
|
||||
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
|
|
@ -58,7 +64,7 @@ __global__ void SkipLayerNormKernelVec(
|
|||
hipcub::KeyValuePair<T, T> thread_data(0, 0);
|
||||
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP];
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];;
|
||||
if (threadIdx.x * ILP < ld) {
|
||||
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
|
||||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
|
||||
|
|
@ -76,9 +82,19 @@ __global__ void SkipLayerNormKernelVec(
|
|||
#pragma unroll
|
||||
for (int k = 0; k < ILP; k++) {
|
||||
input_v[k] += hasBias ? skip_v[k] + bias_v[k] : skip_v[k];
|
||||
|
||||
if (hasSkipInputBiasAdditionOutput) {
|
||||
skip_input_bias_add_output_v[i] = input_v[i];
|
||||
}
|
||||
|
||||
const T rldval = reverse_ld * input_v[k];
|
||||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * input_v[k]));
|
||||
}
|
||||
|
||||
if (hasSkipInputBiasAdditionOutput) {
|
||||
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
|
||||
}
|
||||
|
||||
*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&input_v[0]);
|
||||
}
|
||||
}
|
||||
|
|
@ -90,12 +106,13 @@ __global__ void SkipLayerNormKernelVec(
|
|||
template <typename T, unsigned TPB, int ILP>
|
||||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, bool hasBias) {
|
||||
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
|
||||
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
|
||||
const T rld = T(1.f / ld);
|
||||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
|
||||
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP];
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
|
||||
|
||||
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
|
||||
|
||||
|
|
@ -116,10 +133,20 @@ __global__ void SkipLayerNormKernelSmall(
|
|||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
|
||||
|
||||
if (hasSkipInputBiasAdditionOutput) {
|
||||
skip_input_bias_add_output_v[i] = input_v[i];
|
||||
}
|
||||
|
||||
const T rldval = rld * input_v[i];
|
||||
rldval_sum += rldval;
|
||||
rldvalsq_sum += rldval * input_v[i];
|
||||
}
|
||||
|
||||
if (hasSkipInputBiasAdditionOutput) {
|
||||
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
|
||||
}
|
||||
|
||||
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
|
||||
}
|
||||
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
|
||||
|
|
|
|||
|
|
@ -20,11 +20,11 @@ namespace rocm {
|
|||
|
||||
template <typename T>
|
||||
struct SkipLayerNormParams : OpParams {
|
||||
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input,
|
||||
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
|
||||
const T* skip, const T* gamma, const T* beta,
|
||||
const T* bias, float epsilon, int ld, int element_count)
|
||||
: OpParams(tuning_ctx, stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
|
||||
epsilon(epsilon), ld(ld), element_count(element_count) {}
|
||||
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
|
||||
gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {}
|
||||
|
||||
std::string Signature() const override {
|
||||
std::string sig = std::to_string(ld) + "_" + std::to_string(element_count);
|
||||
|
|
@ -32,6 +32,7 @@ struct SkipLayerNormParams : OpParams {
|
|||
}
|
||||
|
||||
T* output;
|
||||
T* skip_input_bias_add_output;
|
||||
const T* input;
|
||||
const T* skip;
|
||||
const T* gamma;
|
||||
|
|
@ -51,8 +52,8 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
|
|||
dim3(ThreadsPerBlock),
|
||||
0, params->stream>>>(
|
||||
params->ld, params->input, params->skip,
|
||||
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
|
||||
(params->bias == nullptr) ? false : true);
|
||||
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output,
|
||||
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
|
|
@ -66,51 +67,52 @@ Status SkipLayerNormRegularOp(const SkipLayerNormParams<T>* params) {
|
|||
dim3(ThreadsPerBlock),
|
||||
0, params->stream>>>(
|
||||
params->ld, params->input, params->skip,
|
||||
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
|
||||
(params->bias == nullptr) ? false : true);
|
||||
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output,
|
||||
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
|
||||
bool hasBias = (params->bias == nullptr) ? false : true;
|
||||
bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true;
|
||||
if (0 == (params->ld % 4)) {
|
||||
const int grid_size = params->element_count / params->ld;
|
||||
if (params->ld <= 32) {
|
||||
constexpr int block_size = 32;
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 64) {
|
||||
constexpr int block_size = 64 / 2;
|
||||
SkipLayerNormKernelSmall<T, block_size, 2><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 128) {
|
||||
constexpr int block_size = 128 / 4;
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 384) {
|
||||
constexpr int block_size = 384 / 4;
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 768) {
|
||||
constexpr int block_size = 768 / 4;
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 1024) {
|
||||
constexpr int block_size = 1024 / 4;
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output);
|
||||
}
|
||||
} else {
|
||||
const int grid_size = params->element_count / params->ld;
|
||||
|
|
@ -118,27 +120,27 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
|
|||
constexpr int block_size = 32;
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 64) {
|
||||
constexpr int block_size = 64;
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld <= 128) {
|
||||
constexpr int block_size = 128;
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else if (params->ld == 384) {
|
||||
constexpr int block_size = 384;
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output, hasBias);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, params->stream>>>(
|
||||
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
|
||||
maybe2half<T>(params->epsilon), params->output);
|
||||
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output);
|
||||
}
|
||||
}
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
|
|
|
|||
|
|
@ -17,12 +17,12 @@ namespace onnxruntime {
|
|||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
class SkipLayerNormSmall : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
|
|
@ -41,12 +41,12 @@ class SkipLayerNormSmall : public IKernelExplorer {
|
|||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
class SkipLayerNormRegular : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormRegular(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
|
|
@ -65,12 +65,12 @@ class SkipLayerNormRegular : public IKernelExplorer {
|
|||
template <typename T>
|
||||
class SkipLayerNormStaticSelection : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
|
||||
DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T>(¶ms_)));
|
||||
|
|
@ -89,12 +89,12 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
|
|||
template <typename T>
|
||||
class SkipLayerNormTunable : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
|
||||
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
|
||||
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
|
||||
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
|
||||
|
||||
params_.TuningContext()->EnableTunableOp();
|
||||
}
|
||||
|
|
@ -113,14 +113,14 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
|
|
@ -141,7 +141,7 @@ class SkipLayerNormTunable : public IKernelExplorer {
|
|||
#define REGISTER_OP_TYPED(name, type) \
|
||||
py::class_<name<type>>(m, #name "_" #type) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type>::SetRepeats) \
|
||||
.def("Profile", &name<type>::Profile) \
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
|
|||
# Becuase of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small
|
||||
epsilon = 0.05 if hidden_size < 8 else 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
# output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
# enforce nullptr in the backend
|
||||
output_optional = np.empty((0), dtype=dtype)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
|
|
@ -63,8 +66,24 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
|
|||
gamma_d = ke.DeviceArray(gamma)
|
||||
beta_d = ke.DeviceArray(beta)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
optional_d = ke.DeviceArray(output_optional)
|
||||
f = getattr(ke, func)
|
||||
my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size)
|
||||
|
||||
# output_optional is newly added optional output tensor for SkipLayerNorm
|
||||
# Right now we are not testing the tensor like we do for output_y
|
||||
# the test for output_optional could be considered in future as needed
|
||||
my_op = f(
|
||||
y_d,
|
||||
optional_d,
|
||||
input_d,
|
||||
skip_d,
|
||||
gamma_d,
|
||||
beta_d,
|
||||
bias_d,
|
||||
epsilon,
|
||||
hidden_size,
|
||||
batch_size * seq_len * hidden_size,
|
||||
)
|
||||
if my_op.IsSupported():
|
||||
my_op.Run()
|
||||
|
||||
|
|
@ -106,6 +125,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
bias = np.random.rand(hidden_size).astype(dtype)
|
||||
epsilon = 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
# output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
# enforce nullptr in the backend - optional when profiling
|
||||
output_optional = np.empty((0), dtype=dtype)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
|
|
@ -113,8 +135,24 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
|||
beta_d = ke.DeviceArray(beta)
|
||||
bias_d = ke.DeviceArray(bias)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
optional_d = ke.DeviceArray(output_optional)
|
||||
f = getattr(ke, func)
|
||||
my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size)
|
||||
|
||||
# output_optional is newly added optional output tensor for SkipLayerNorm
|
||||
# Right now we are not testing the tensor like we do for output_y
|
||||
# the profile for output_optional could be considered in future as needed
|
||||
my_op = f(
|
||||
y_d,
|
||||
optional_d,
|
||||
input_d,
|
||||
skip_d,
|
||||
gamma_d,
|
||||
beta_d,
|
||||
bias_d,
|
||||
epsilon,
|
||||
hidden_size,
|
||||
batch_size * seq_len * hidden_size,
|
||||
)
|
||||
|
||||
duration_ms = -1
|
||||
if my_op.IsSupported():
|
||||
|
|
|
|||
|
|
@ -436,8 +436,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias) {
|
|||
hidden_size);
|
||||
}
|
||||
|
||||
// Don't enable this test for DML/ROCM builds as these EP doesn't produce the new optional output yet
|
||||
#if !defined(USE_ROCM) && !defined(USE_DML)
|
||||
// Don't enable this test for DML builds as these EP doesn't produce the new optional output yet
|
||||
#if !defined(USE_DML)
|
||||
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
|
|
@ -494,6 +494,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) {
|
|||
hidden_size);
|
||||
}
|
||||
|
||||
// SkipSimplifiedLayerNorm has not been enabled for ROCm yet
|
||||
#if !defined(USE_ROCM)
|
||||
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
|
|
@ -530,6 +532,7 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
|
|||
true);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue