[ROCm] Support for gpt2-based model inferencing (#14675)

When inferencing real gpt2-based model, found some gaps between CUDA and
ROCm codebase.

The fixes include:

1. minimum code change to fix tensor shape on Attention Op
2. Support optional output tensor with SkipLayerNorm
3. fix a build error found on MI200

---------

Co-authored-by: Ubuntu <ettao@ettao-amd-dev1.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
This commit is contained in:
ytaous 2023-02-15 00:16:00 -08:00 committed by GitHub
parent a216c9a3fa
commit d49cea05fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 175 additions and 95 deletions

View file

@ -15,6 +15,10 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {
constexpr int kPastSequenceLengthInputIndex = 6;
constexpr int kPastInputIndex = 4;
constexpr int kPresentOutputIndex = 1;
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Attention, \
@ -22,8 +26,10 @@ namespace rocm {
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
(*KernelDefBuilder::Create()) \
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
Attention<T>);
REGISTER_KERNEL_TYPED(float)
@ -40,47 +46,41 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(4);
const Tensor* relative_position_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
auto& device_prop = GetDeviceProp();
AttentionParameters parameters;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
mask_index,
past,
relative_position_bias,
nullptr,
device_prop.maxThreadsPerBlock));
// input shape (batch_size, sequence_length, input_hidden_size)
const auto& shape = input->Shape();
int batch_size = static_cast<int>(shape[0]);
int sequence_length = static_cast<int>(shape[1]);
int input_hidden_size = static_cast<int>(shape[2]);
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
int head_size = hidden_size / num_heads_;
&parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
ORT_ENFORCE(parameters.sequence_length == parameters.kv_sequence_length); // self attention
TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
output_shape[0] = static_cast<int64_t>(parameters.batch_size);
output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);
int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
std::vector<int64_t> present_dims{
2, parameters.batch_size, parameters.num_heads,
parameters.past_present_share_buffer ? parameters.max_sequence_length : parameters.total_sequence_length,
parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present = context->Output(kPresentOutputIndex, present_shape);
rocblas_handle rocblas = GetRocblasHandle(context);
constexpr size_t element_size = sizeof(T);
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size, context->GetComputeStream());
int m = parameters.batch_size * parameters.sequence_length;
int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size);
int k = parameters.input_hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
typedef typename ToHipType<T>::MappedType HipT;
namespace blas = rocm::tunable::blas;
@ -108,8 +108,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
/*beta=*/1.0f,
reinterpret_cast<HipT*>(gemm_buffer.get()), n));
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
sequence_length, past_sequence_length);
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
parameters.head_size,
parameters.sequence_length,
parameters.past_sequence_length);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
return LaunchAttentionKernel(
@ -118,16 +122,16 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Stream(context),
rocblas,
element_size,
batch_size,
sequence_length,
num_heads_,
head_size,
past_sequence_length,
is_unidirectional_,
parameters.batch_size,
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
parameters.past_sequence_length,
parameters.is_unidirectional,
reinterpret_cast<const void*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->Data<int>(),
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
mask_filter_value_,
parameters.mask_filter_value,
nullptr == past ? nullptr : past->Data<T>(),
nullptr == relative_position_bias ? nullptr : relative_position_bias->Data<T>(),
work_space.get(),

View file

@ -43,6 +43,10 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
Tensor* output = ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum
// of the input and skip tensors
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
@ -101,6 +105,7 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
GetTuningContext(),
Stream(ctx),
reinterpret_cast<HipT*>(output->MutableData<T>()),
skip_input_bias_add_output != nullptr ? reinterpret_cast<HipT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
reinterpret_cast<const HipT*>(input->Data<T>()),
reinterpret_cast<const HipT*>(skip->Data<T>()),
reinterpret_cast<const HipT*>(gamma->Data<T>()),

View file

@ -41,12 +41,12 @@ namespace rocm {
template <typename T>
Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, int ld, int element_count) {
RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, int ld, int element_count) {
// this must be true because element_count is the total size of the tensor
assert(element_count % ld == 0);
SkipLayerNormParams<T> params(tuning_ctx, stream, output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
SkipLayerNormParams<T> params(tuning_ctx, stream, output, skip_input_bias_add_output, input, skip, gamma, beta, bias, epsilon, ld, element_count);
if (tuning_ctx->IsTunableOpEnabled()) {
static SkipLayerNormTunableOp<T> op;
@ -57,13 +57,13 @@ Status LaunchSkipLayerNormKernel(
}
template Status LaunchSkipLayerNormKernel<float>(
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, const float* input,
RocmTuningContext* tuning_ctx, hipStream_t stream, float* output, float* skip_input_bias_add_output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, int ld,
int element_count);
template Status LaunchSkipLayerNormKernel<half>(
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, const half* input,
RocmTuningContext* tuning_ctx, hipStream_t stream, half* output, half* skip_input_bias_add_output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, int ld,
int element_count);

View file

@ -15,6 +15,7 @@ Status LaunchSkipLayerNormKernel(
RocmTuningContext* tuning,
hipStream_t stream,
T* output, // output tensor
T* skip_input_bias_add_output, // optional output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor

View file

@ -26,7 +26,7 @@ half maybe2half(float x) {
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T epsilon, T* output, T* skip_input_bias_add_output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
@ -39,6 +39,11 @@ __global__ void SkipLayerNormKernel(
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
if (skip_input_bias_add_output != nullptr) {
skip_input_bias_add_output[idx] = val;
}
output[idx] = val;
}
@ -49,7 +54,8 @@ __global__ void SkipLayerNormKernel(
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelVec(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
@ -58,7 +64,7 @@ __global__ void SkipLayerNormKernelVec(
hipcub::KeyValuePair<T, T> thread_data(0, 0);
using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];;
if (threadIdx.x * ILP < ld) {
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
@ -76,9 +82,19 @@ __global__ void SkipLayerNormKernelVec(
#pragma unroll
for (int k = 0; k < ILP; k++) {
input_v[k] += hasBias ? skip_v[k] + bias_v[k] : skip_v[k];
if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[i] = input_v[i];
}
const T rldval = reverse_ld * input_v[k];
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * input_v[k]));
}
if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
}
*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&input_v[0]);
}
}
@ -90,12 +106,13 @@ __global__ void SkipLayerNormKernelVec(
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
T input_v[ILP], skip_v[ILP], bias_v[ILP], skip_input_bias_add_output_v[ILP];
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
@ -116,10 +133,20 @@ __global__ void SkipLayerNormKernelSmall(
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
if (hasSkipInputBiasAdditionOutput) {
skip_input_bias_add_output_v[i] = input_v[i];
}
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}
if (hasSkipInputBiasAdditionOutput) {
*(reinterpret_cast<VecT*>(&skip_input_bias_add_output[idx])) = *reinterpret_cast<VecT*>(&skip_input_bias_add_output_v);
}
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);

View file

@ -20,11 +20,11 @@ namespace rocm {
template <typename T>
struct SkipLayerNormParams : OpParams {
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, const T* input,
SkipLayerNormParams(RocmTuningContext* tuning_ctx, hipStream_t stream, T* output, T* skip_input_bias_add_output, const T* input,
const T* skip, const T* gamma, const T* beta,
const T* bias, float epsilon, int ld, int element_count)
: OpParams(tuning_ctx, stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
epsilon(epsilon), ld(ld), element_count(element_count) {}
: OpParams(tuning_ctx, stream), output(output), skip_input_bias_add_output(skip_input_bias_add_output), input(input), skip(skip),
gamma(gamma), beta(beta), bias(bias), epsilon(epsilon), ld(ld), element_count(element_count) {}
std::string Signature() const override {
std::string sig = std::to_string(ld) + "_" + std::to_string(element_count);
@ -32,6 +32,7 @@ struct SkipLayerNormParams : OpParams {
}
T* output;
T* skip_input_bias_add_output;
const T* input;
const T* skip;
const T* gamma;
@ -51,8 +52,8 @@ Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
dim3(ThreadsPerBlock),
0, params->stream>>>(
params->ld, params->input, params->skip,
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
(params->bias == nullptr) ? false : true);
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output,
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
return HIP_CALL(hipGetLastError());
}
@ -66,51 +67,52 @@ Status SkipLayerNormRegularOp(const SkipLayerNormParams<T>* params) {
dim3(ThreadsPerBlock),
0, params->stream>>>(
params->ld, params->input, params->skip,
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
(params->bias == nullptr) ? false : true);
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output,
(params->bias == nullptr) ? false : true, (params->skip_input_bias_add_output == nullptr) ? false : true);
return HIP_CALL(hipGetLastError());
}
template <typename T>
Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
bool hasBias = (params->bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (params->skip_input_bias_add_output == nullptr) ? false : true;
if (0 == (params->ld % 4)) {
const int grid_size = params->element_count / params->ld;
if (params->ld <= 32) {
constexpr int block_size = 32;
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 64) {
constexpr int block_size = 64 / 2;
SkipLayerNormKernelSmall<T, block_size, 2><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 128) {
constexpr int block_size = 128 / 4;
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 384) {
constexpr int block_size = 384 / 4;
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 768) {
constexpr int block_size = 768 / 4;
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 1024) {
constexpr int block_size = 1024 / 4;
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else {
constexpr int block_size = 256;
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output);
}
} else {
const int grid_size = params->element_count / params->ld;
@ -118,27 +120,27 @@ Status SkipLayerNormStaticSelection(const SkipLayerNormParams<T>* params) {
constexpr int block_size = 32;
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 64) {
constexpr int block_size = 64;
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld <= 128) {
constexpr int block_size = 128;
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else if (params->ld == 384) {
constexpr int block_size = 384;
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output, hasBias);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput);
} else {
constexpr int block_size = 256;
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, params->stream>>>(
params->ld, params->input, params->skip, params->beta, params->gamma, params->bias,
maybe2half<T>(params->epsilon), params->output);
maybe2half<T>(params->epsilon), params->output, params->skip_input_bias_add_output);
}
}
return HIP_CALL(hipPeekAtLastError());

View file

@ -17,12 +17,12 @@ namespace onnxruntime {
template <typename T, int ThreadsPerBlock, int VecSize>
class SkipLayerNormSmall : public IKernelExplorer {
public:
SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
SkipLayerNormSmall(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(&params_)));
@ -41,12 +41,12 @@ class SkipLayerNormSmall : public IKernelExplorer {
template <typename T, int ThreadsPerBlock, int VecSize>
class SkipLayerNormRegular : public IKernelExplorer {
public:
SkipLayerNormRegular(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
SkipLayerNormRegular(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormRegularOp<T, ThreadsPerBlock, VecSize>(&params_)));
@ -65,12 +65,12 @@ class SkipLayerNormRegular : public IKernelExplorer {
template <typename T>
class SkipLayerNormStaticSelection : public IKernelExplorer {
public:
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
SkipLayerNormStaticSelection(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input,
DeviceArray& skip, DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormStaticSelection<T>(&params_)));
@ -89,12 +89,12 @@ class SkipLayerNormStaticSelection : public IKernelExplorer {
template <typename T>
class SkipLayerNormTunable : public IKernelExplorer {
public:
SkipLayerNormTunable(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
SkipLayerNormTunable(DeviceArray& output, DeviceArray& skip_input_bias_add_output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
: params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(skip_input_bias_add_output.ptr()),
static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()),
static_cast<T*>(beta.ptr()), static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {
params_.TuningContext()->EnableTunableOp();
}
@ -113,14 +113,14 @@ class SkipLayerNormTunable : public IKernelExplorer {
contrib::rocm::SkipLayerNormTunableOp<T> op_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
@ -141,7 +141,7 @@ class SkipLayerNormTunable : public IKernelExplorer {
#define REGISTER_OP_TYPED(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \

View file

@ -56,6 +56,9 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
# Becuase of rocm FMAs calculation issue with float16, epsilon should be larger when hidden_size is small
epsilon = 0.05 if hidden_size < 8 else 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
# output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
# enforce nullptr in the backend
output_optional = np.empty((0), dtype=dtype)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
@ -63,8 +66,24 @@ def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype:
gamma_d = ke.DeviceArray(gamma)
beta_d = ke.DeviceArray(beta)
y_d = ke.DeviceArray(output_y)
optional_d = ke.DeviceArray(output_optional)
f = getattr(ke, func)
my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size)
# output_optional is newly added optional output tensor for SkipLayerNorm
# Right now we are not testing the tensor like we do for output_y
# the test for output_optional could be considered in future as needed
my_op = f(
y_d,
optional_d,
input_d,
skip_d,
gamma_d,
beta_d,
bias_d,
epsilon,
hidden_size,
batch_size * seq_len * hidden_size,
)
if my_op.IsSupported():
my_op.Run()
@ -106,6 +125,9 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
bias = np.random.rand(hidden_size).astype(dtype)
epsilon = 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
# output_optional = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
# enforce nullptr in the backend - optional when profiling
output_optional = np.empty((0), dtype=dtype)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
@ -113,8 +135,24 @@ def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
beta_d = ke.DeviceArray(beta)
bias_d = ke.DeviceArray(bias)
y_d = ke.DeviceArray(output_y)
optional_d = ke.DeviceArray(output_optional)
f = getattr(ke, func)
my_op = f(y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size)
# output_optional is newly added optional output tensor for SkipLayerNorm
# Right now we are not testing the tensor like we do for output_y
# the profile for output_optional could be considered in future as needed
my_op = f(
y_d,
optional_d,
input_d,
skip_d,
gamma_d,
beta_d,
bias_d,
epsilon,
hidden_size,
batch_size * seq_len * hidden_size,
)
duration_ms = -1
if my_op.IsSupported():

View file

@ -436,8 +436,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias) {
hidden_size);
}
// Don't enable this test for DML/ROCM builds as these EP doesn't produce the new optional output yet
#if !defined(USE_ROCM) && !defined(USE_DML)
// Don't enable this test for DML builds as these EP doesn't produce the new optional output yet
#if !defined(USE_DML)
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) {
int batch_size = 2;
int sequence_length = 2;
@ -494,6 +494,8 @@ TEST(SkipLayerNormTest, SkipLayerNormBatch2_Bias_ProducingOptionalOutput) {
hidden_size);
}
// SkipSimplifiedLayerNorm has not been enabled for ROCm yet
#if !defined(USE_ROCM)
TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
@ -530,6 +532,7 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
true);
}
#endif
#endif
} // namespace test
} // namespace onnxruntime