mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-16 01:33:39 +00:00
[ADD] add skip layernorm to kernel explorer for ROCm EP (#12816)
**Description**: Describe your changes. Related PR: https://github.com/microsoft/onnxruntime/pull/12803 https://github.com/microsoft/onnxruntime/pull/12817 https://github.com/microsoft/onnxruntime/pull/12821 Add skip layernorm to kernel explorer for profiling. **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here.
This commit is contained in:
parent
ffeba98a9d
commit
189aef2bea
9 changed files with 407 additions and 127 deletions
|
|
@ -1,8 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm.h"
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -93,21 +94,19 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
int sequence_length = static_cast<int>(input_dims[1]);
|
||||
int hidden_size = static_cast<int>(input_dims[2]);
|
||||
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
|
||||
size_t element_size = sizeof(T);
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
return LaunchSkipLayerNormKernel<HipT>(
|
||||
Stream(),
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(skip->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(gamma->Data<T>()),
|
||||
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
|
||||
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count),
|
||||
element_size);
|
||||
Stream(),
|
||||
reinterpret_cast<HipT*>(output->MutableData<T>()),
|
||||
reinterpret_cast<const HipT*>(input->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(skip->Data<T>()),
|
||||
reinterpret_cast<const HipT*>(gamma->Data<T>()),
|
||||
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
|
||||
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count));
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -28,93 +28,20 @@ limitations under the License.
|
|||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include "contrib_ops/rocm/bert/layer_norm.cuh"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
T maybe2half(float x);
|
||||
|
||||
template <>
|
||||
float maybe2half(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <>
|
||||
half maybe2half(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
|
||||
const T epsilon, T* output) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
// reduce x and x^2
|
||||
hipcub::KeyValuePair<T, T> thread_data(0, 0);
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
|
||||
const T rldval = reverse_ld * val;
|
||||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
|
||||
output[idx] = val;
|
||||
}
|
||||
|
||||
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
// Vectorized kernel
|
||||
template <typename T, unsigned TPB, int ILP>
|
||||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, bool hasBias) {
|
||||
const T rld = T(1.f / ld);
|
||||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
|
||||
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP];
|
||||
|
||||
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
|
||||
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
|
||||
|
||||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
|
||||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
|
||||
|
||||
if (hasBias) {
|
||||
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
|
||||
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
|
||||
}
|
||||
|
||||
T rldval_sum = T(0.f);
|
||||
T rldvalsq_sum = T(0.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
|
||||
const T rldval = rld * input_v[i];
|
||||
rldval_sum += rldval;
|
||||
rldvalsq_sum += rldval * input_v[i];
|
||||
}
|
||||
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
|
||||
}
|
||||
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
|
||||
size_t element_size) {
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count) {
|
||||
// this must be true because n is the total size of the tensor
|
||||
assert(element_count % ld == 0);
|
||||
bool hasBias = (bias == nullptr) ? false : true;
|
||||
|
|
@ -122,55 +49,55 @@ Status LaunchSkipLayerNormKernel(
|
|||
const int grid_size = element_count / ld;
|
||||
if (ld <= 32) {
|
||||
constexpr int block_size = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 64) {
|
||||
constexpr int block_size = 64 / 2;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 2>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 2><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 128) {
|
||||
constexpr int block_size = 128 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 384) {
|
||||
constexpr int block_size = 384 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 768) {
|
||||
constexpr int block_size = 768 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 1024) {
|
||||
constexpr int block_size = 1024 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
}
|
||||
} else {
|
||||
const int grid_size = element_count / ld;
|
||||
if (ld <= 32) {
|
||||
constexpr int block_size = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 64) {
|
||||
constexpr int block_size = 64;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 128) {
|
||||
constexpr int block_size = 128;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld == 384) {
|
||||
constexpr int block_size = 384;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
}
|
||||
}
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
|
|
@ -179,12 +106,12 @@ Status LaunchSkipLayerNormKernel(
|
|||
template Status LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, const int ld,
|
||||
const int element_count, size_t element_size);
|
||||
const int element_count);
|
||||
|
||||
template Status LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, const int ld,
|
||||
const int element_count, size_t element_size);
|
||||
const int element_count);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -11,16 +11,15 @@ namespace rocm {
|
|||
template <typename T>
|
||||
Status LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream,
|
||||
T* output, // output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
const T* beta, // Layer normalization beta tensor
|
||||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count, // number of elements in input tensor
|
||||
size_t element_size
|
||||
T* output, // output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
const T* beta, // Layer normalization beta tensor
|
||||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count // number of elements in input tensor
|
||||
);
|
||||
|
||||
} // namespace rocm
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include "contrib_ops/rocm/bert/layer_norm.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
T maybe2half(float x);
|
||||
|
||||
template <>
|
||||
float maybe2half(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template <>
|
||||
half maybe2half(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
|
||||
const T epsilon, T* output) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
// reduce x and x^2
|
||||
hipcub::KeyValuePair<T, T> thread_data(0, 0);
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
|
||||
const T rldval = reverse_ld * val;
|
||||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
|
||||
output[idx] = val;
|
||||
}
|
||||
|
||||
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
// Vectorized kernel
|
||||
template <typename T, unsigned TPB, int ILP>
|
||||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, bool hasBias) {
|
||||
const T rld = T(1.f / ld);
|
||||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
|
||||
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP];
|
||||
|
||||
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
|
||||
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
|
||||
|
||||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
|
||||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
|
||||
|
||||
if (hasBias) {
|
||||
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
|
||||
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
|
||||
}
|
||||
|
||||
T rldval_sum = T(0.f);
|
||||
T rldvalsq_sum = T(0.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
|
||||
const T rldval = rld * input_v[i];
|
||||
rldval_sum += rldval;
|
||||
rldvalsq_sum += rldval * input_v[i];
|
||||
}
|
||||
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
|
||||
}
|
||||
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
58
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h
Normal file
58
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
|
||||
#include "contrib_ops/rocm/bert/tunable_op.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
struct SkipLayerNormParams : OpParams {
|
||||
SkipLayerNormParams(hipStream_t stream, T* output, const T* input,
|
||||
const T* skip, const T* gamma, const T* beta,
|
||||
const T* bias, float epsilon, const int ld,
|
||||
const int element_count)
|
||||
: OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
|
||||
epsilon(epsilon), ld(ld), element_count(element_count) {}
|
||||
|
||||
std::string Signature() const override {
|
||||
std::string sig = std::to_string(ld) + "_" + std::to_string(element_count);
|
||||
return sig;
|
||||
}
|
||||
|
||||
T* output;
|
||||
const T* input;
|
||||
const T* skip;
|
||||
const T* gamma;
|
||||
const T* beta;
|
||||
const T* bias;
|
||||
const float epsilon;
|
||||
const int ld;
|
||||
const int element_count;
|
||||
};
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
|
||||
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
|
||||
!((params->ld <= 1024 && params->ld % VecSize == 0 && params->ld == ThreadsPerBlock * VecSize)));
|
||||
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilingDivision(params->element_count, params->ld)),
|
||||
dim3(ThreadsPerBlock),
|
||||
0, params->stream>>>(
|
||||
params->ld, params->input, params->skip,
|
||||
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
|
||||
(params->bias == nullptr) ? false : true);
|
||||
return HIP_CALL(hipGetLastError());
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
#include "python/tools/kernel_explorer/kernels/vector_add.h"
|
||||
#include "python/tools/kernel_explorer/kernels/fast_gelu.h"
|
||||
#include "python/tools/kernel_explorer/kernels/gemm.h"
|
||||
#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
|
@ -19,6 +20,7 @@ PYBIND11_MODULE(_kernel_explorer, m) {
|
|||
InitVectorAdd(m);
|
||||
InitFastGelu(m);
|
||||
InitGemm(m);
|
||||
InitSkipLayerNorm(m);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h"
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_op.h"
|
||||
#include "python/tools/kernel_explorer/device_array.h"
|
||||
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename T, int ThreadsPerBlock, int VecSize>
|
||||
class SkipLayerNormSmall : public IKernelExplorer {
|
||||
public:
|
||||
SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
|
||||
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
|
||||
float epsilon, int hidden_size, int element_count)
|
||||
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
|
||||
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
|
||||
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
|
||||
|
||||
void Run() override {
|
||||
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(¶ms_)));
|
||||
}
|
||||
|
||||
bool IsSupported() {
|
||||
Status status = contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(¶ms_);
|
||||
return status.IsOK();
|
||||
}
|
||||
|
||||
private:
|
||||
using ParamsT = contrib::rocm::SkipLayerNormParams<T>;
|
||||
ParamsT params_{};
|
||||
};
|
||||
|
||||
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
|
||||
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
|
||||
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
|
||||
DeviceArray&, DeviceArray&, \
|
||||
float, int, int>()) \
|
||||
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
|
||||
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
|
||||
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
|
||||
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
|
||||
REGISTER_OP(name, type, threads_per_block, 1) \
|
||||
REGISTER_OP(name, type, threads_per_block, 2) \
|
||||
REGISTER_OP(name, type, threads_per_block, 4) \
|
||||
REGISTER_OP(name, type, threads_per_block, 8) \
|
||||
REGISTER_OP(name, type, threads_per_block, 16)
|
||||
|
||||
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
|
||||
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
|
||||
|
||||
void InitSkipLayerNorm(py::module m) {
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
|
||||
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
void InitSkipLayerNorm(py::module m);
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from itertools import product
|
||||
|
||||
import kernel_explorer as ke
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
def get_bert_sizes():
|
||||
batch_sizes = [1, 8, 64, 128]
|
||||
seq_lens = [64, 128, 256, 384, 512]
|
||||
hidden_sizes = [768, 1024]
|
||||
return product(batch_sizes, seq_lens, hidden_sizes)
|
||||
|
||||
|
||||
def dtype_to_funcs(dtype):
|
||||
type_map = {
|
||||
"float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))),
|
||||
"float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))),
|
||||
}
|
||||
return type_map[dtype]
|
||||
|
||||
|
||||
def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
|
||||
val = input_x + skip + bias
|
||||
x_u = np.mean(val, axis=(2,))
|
||||
x_s = np.var(val, axis=(2,))
|
||||
output = val - x_u[..., None]
|
||||
output = output / np.sqrt(x_s + epsilon)[..., None]
|
||||
output = output * gamma + beta
|
||||
return output
|
||||
|
||||
|
||||
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
bias = np.random.rand(hidden_size).astype(dtype)
|
||||
gamma = np.random.rand(hidden_size).astype(dtype)
|
||||
beta = np.random.rand((hidden_size)).astype(dtype)
|
||||
epsilon = 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
bias_d = ke.DeviceArray(bias)
|
||||
gamma_d = ke.DeviceArray(gamma)
|
||||
beta_d = ke.DeviceArray(beta)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
my_func = getattr(ke, func)
|
||||
my_op = my_func(
|
||||
y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size
|
||||
)
|
||||
if my_op.IsSupported():
|
||||
my_op.Run()
|
||||
|
||||
y_d.UpdateHostNumpyArray()
|
||||
|
||||
y_ref = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
|
||||
np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05)
|
||||
|
||||
|
||||
dtypes = ["float32", "float16"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bert_sizes", get_bert_sizes())
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_skip_layer_norm(bert_sizes, dtype):
|
||||
for func in dtype_to_funcs(dtype):
|
||||
print(func)
|
||||
run_skip_layer_norm(*bert_sizes, dtype, func)
|
||||
|
||||
|
||||
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
|
||||
np.random.seed(0)
|
||||
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
gamma = np.random.rand(hidden_size).astype(dtype)
|
||||
beta = np.random.rand(hidden_size).astype(dtype)
|
||||
bias = np.random.rand(hidden_size).astype(dtype)
|
||||
epsilon = 0.0005
|
||||
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
|
||||
|
||||
input_d = ke.DeviceArray(input_x)
|
||||
skip_d = ke.DeviceArray(skip)
|
||||
gamma_d = ke.DeviceArray(gamma)
|
||||
beta_d = ke.DeviceArray(beta)
|
||||
bias_d = ke.DeviceArray(bias)
|
||||
y_d = ke.DeviceArray(output_y)
|
||||
my_func = getattr(ke, func)
|
||||
my_op = my_func(
|
||||
y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size
|
||||
)
|
||||
if my_op.IsSupported():
|
||||
duration = my_op.Profile()
|
||||
print(
|
||||
dtype,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
my_func,
|
||||
f"{duration * 1000:.2f} us",
|
||||
f"{(input_x.size * 3 + bias.size * 3) * input_x.itemsize * 1e3 / duration / 1e9:.2f} GB/s",
|
||||
)
|
||||
|
||||
|
||||
def profile():
|
||||
bert_sizes = get_bert_sizes()
|
||||
for dtype in dtypes:
|
||||
for bert_size in bert_sizes:
|
||||
for func in dtype_to_funcs(dtype):
|
||||
profile_skip_layer_norm_func(*bert_size, dtype, func)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
profile()
|
||||
Loading…
Reference in a new issue