diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc index 4807c34355..b8c6413c37 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/rocm/rocm_common.h" #include "contrib_ops/rocm/bert/skip_layer_norm.h" + +#include "core/providers/rocm/rocm_common.h" #include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" namespace onnxruntime { @@ -93,21 +94,19 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { int sequence_length = static_cast(input_dims[1]); int hidden_size = static_cast(input_dims[2]); int64_t element_count = input_dims[0] * sequence_length * hidden_size; - size_t element_size = sizeof(T); typedef typename ToHipType::MappedType HipT; return LaunchSkipLayerNormKernel( - Stream(), - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - reinterpret_cast(skip->Data()), - reinterpret_cast(gamma->Data()), - (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, - (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, - epsilon_, - hidden_size, - static_cast(element_count), - element_size); + Stream(), + reinterpret_cast(output->MutableData()), + reinterpret_cast(input->Data()), + reinterpret_cast(skip->Data()), + reinterpret_cast(gamma->Data()), + (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, + (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, + epsilon_, + hidden_size, + static_cast(element_count)); } } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu index 73057aa9a4..120ccd9162 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -28,93 +28,20 @@ limitations under the License. // Copyright (c) Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. -#include -#include "contrib_ops/rocm/bert/layer_norm.cuh" #include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" + namespace onnxruntime { namespace contrib { namespace rocm { -template -T maybe2half(float x); - -template <> -float maybe2half(float x) { - return x; -} - -template <> -half maybe2half(float x) { - return __float2half_rn(x); -} - -template -__global__ void SkipLayerNormKernel( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output) { - const T reverse_ld = T(1.f / ld); - const int offset = blockIdx.x * ld; - - KeyValuePairSum pair_sum; - // reduce x and x^2 - hipcub::KeyValuePair thread_data(0, 0); - - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; - const T rldval = reverse_ld * val; - thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); - output[idx] = val; - } - - LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); -} - -// Vectorized kernel -template -__global__ void SkipLayerNormKernelSmall( - const int ld, const T* input, const T* skip, const T* beta, const T* gamma, - const T* bias, const T epsilon, T* output, bool hasBias) { - const T rld = T(1.f / ld); - const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld - - using VecT = aligned_vector; - T input_v[ILP], skip_v[ILP], bias_v[ILP]; - - hipcub::KeyValuePair thread_data(T(0.f), T(0.f)); - - if (ILP * threadIdx.x < ld) { - VecT* input_val = reinterpret_cast(&input_v); - *input_val = *reinterpret_cast(&input[idx]); - - VecT* skip_val = reinterpret_cast(&skip_v); - *skip_val = *reinterpret_cast(&skip[idx]); - - if (hasBias) { - VecT* bias_val = reinterpret_cast(&bias_v); - *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); - } - - T rldval_sum = T(0.f); - T rldvalsq_sum = T(0.f); - #pragma unroll - for (int i = 0; i < ILP; i++) { - input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; - const T rldval = rld * input_v[i]; - rldval_sum += rldval; - rldvalsq_sum += rldval * input_v[i]; - } - thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); - } - LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); -} - template Status LaunchSkipLayerNormKernel( hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, - const T* beta, const T* bias, float epsilon, const int ld, const int element_count, - size_t element_size) { + const T* beta, const T* bias, float epsilon, const int ld, const int element_count) { // this must be true because n is the total size of the tensor assert(element_count % ld == 0); bool hasBias = (bias == nullptr) ? false : true; @@ -122,55 +49,55 @@ Status LaunchSkipLayerNormKernel( const int grid_size = element_count / ld; if (ld <= 32) { constexpr int block_size = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 64) { constexpr int block_size = 64 / 2; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 128) { constexpr int block_size = 128 / 4; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 384) { constexpr int block_size = 384 / 4; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 768) { constexpr int block_size = 768 / 4; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 1024) { constexpr int block_size = 1024 / 4; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else { constexpr int block_size = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); + SkipLayerNormKernel<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); } } else { const int grid_size = element_count / ld; if (ld <= 32) { constexpr int block_size = 32; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 64) { constexpr int block_size = 64; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld <= 128) { constexpr int block_size = 128; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else if (ld == 384) { constexpr int block_size = 384; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + SkipLayerNormKernelSmall<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); } else { constexpr int block_size = 256; - hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel), grid_size, block_size, - 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); + SkipLayerNormKernel<<>>( + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); } } return HIP_CALL(hipPeekAtLastError()); @@ -179,12 +106,12 @@ Status LaunchSkipLayerNormKernel( template Status LaunchSkipLayerNormKernel(hipStream_t stream, float* output, const float* input, const float* skip, const float* gamma, const float* beta, const float* bias, float epsilon, const int ld, - const int element_count, size_t element_size); + const int element_count); template Status LaunchSkipLayerNormKernel(hipStream_t stream, half* output, const half* input, const half* skip, const half* gamma, const half* beta, const half* bias, float epsilon, const int ld, - const int element_count, size_t element_size); + const int element_count); } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h index d866a0bde6..727116a50f 100644 --- a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -11,16 +11,15 @@ namespace rocm { template Status LaunchSkipLayerNormKernel( hipStream_t stream, - T* output, // output tensor - const T* input, // input tensor - const T* skip, // skip tensor - const T* gamma, // Layer normalization gamma tensor - const T* beta, // Layer normalization beta tensor - const T* bias, // Layer normalization beta tensor - float epsilon, // Layer normalization epsilon - int hidden_size, // hidden size, it is the leading dimension (ld) - int element_count, // number of elements in input tensor - size_t element_size + T* output, // output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor + const T* bias, // Layer normalization beta tensor + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count // number of elements in input tensor ); } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h new file mode 100644 index 0000000000..9af848f36b --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "contrib_ops/rocm/bert/layer_norm.cuh" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +T maybe2half(float x); + +template <> +float maybe2half(float x) { + return x; +} + +template <> +half maybe2half(float x) { + return __float2half_rn(x); +} + +template +__global__ void SkipLayerNormKernel( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, + const T epsilon, T* output) { + const T reverse_ld = T(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(0, 0); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; + const T rldval = reverse_ld * val; + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + output[idx] = val; + } + + LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel +template +__global__ void SkipLayerNormKernelSmall( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, + const T* bias, const T epsilon, T* output, bool hasBias) { + const T rld = T(1.f / ld); + const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld + + using VecT = aligned_vector; + T input_v[ILP], skip_v[ILP], bias_v[ILP]; + + hipcub::KeyValuePair thread_data(T(0.f), T(0.f)); + + if (ILP * threadIdx.x < ld) { + VecT* input_val = reinterpret_cast(&input_v); + *input_val = *reinterpret_cast(&input[idx]); + + VecT* skip_val = reinterpret_cast(&skip_v); + *skip_val = *reinterpret_cast(&skip[idx]); + + if (hasBias) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); + } + + T rldval_sum = T(0.f); + T rldvalsq_sum = T(0.f); + #pragma unroll + for (int i = 0; i < ILP; i++) { + input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; + const T rldval = rld * input_v[i]; + rldval_sum += rldval; + rldvalsq_sum += rldval * input_v[i]; + } + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); + } + LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h new file mode 100644 index 0000000000..36b99243e9 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_op.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h" +#include "contrib_ops/rocm/bert/tunable_op.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +struct SkipLayerNormParams : OpParams { + SkipLayerNormParams(hipStream_t stream, T* output, const T* input, + const T* skip, const T* gamma, const T* beta, + const T* bias, float epsilon, const int ld, + const int element_count) + : OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias), + epsilon(epsilon), ld(ld), element_count(element_count) {} + + std::string Signature() const override { + std::string sig = std::to_string(ld) + "_" + std::to_string(element_count); + return sig; + } + + T* output; + const T* input; + const T* skip; + const T* gamma; + const T* beta; + const T* bias; + const float epsilon; + const int ld; + const int element_count; +}; + +template +Status SkipLayerNormSmallOp(const SkipLayerNormParams* params) { + TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF( + !((params->ld <= 1024 && params->ld % VecSize == 0 && params->ld == ThreadsPerBlock * VecSize))); + SkipLayerNormKernelSmall<<element_count, params->ld)), + dim3(ThreadsPerBlock), + 0, params->stream>>>( + params->ld, params->input, params->skip, + params->beta, params->gamma, params->bias, maybe2half(params->epsilon), params->output, + (params->bias == nullptr) ? false : true); + return HIP_CALL(hipGetLastError()); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc index fdb6a2855c..aaa21ea039 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc +++ b/onnxruntime/python/tools/kernel_explorer/kernel_explorer.cc @@ -7,6 +7,7 @@ #include "python/tools/kernel_explorer/kernels/vector_add.h" #include "python/tools/kernel_explorer/kernels/fast_gelu.h" #include "python/tools/kernel_explorer/kernels/gemm.h" +#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h" namespace py = pybind11; @@ -19,6 +20,7 @@ PYBIND11_MODULE(_kernel_explorer, m) { InitVectorAdd(m); InitFastGelu(m); InitGemm(m); + InitSkipLayerNorm(m); } } // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc new file mode 100644 index 0000000000..8fcd92159c --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.cc @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h" + +#include +#include + +#include "contrib_ops/rocm/bert/skip_layer_norm_op.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" + +namespace py = pybind11; + +namespace onnxruntime { + +template +class SkipLayerNormSmall : public IKernelExplorer { + public: + SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip, + DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias, + float epsilon, int hidden_size, int element_count) + : params_(this->Stream(), static_cast(output.ptr()), static_cast(input.ptr()), + static_cast(skip.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), + static_cast(bias.ptr()), epsilon, hidden_size, element_count) {} + + void Run() override { + ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp(¶ms_))); + } + + bool IsSupported() { + Status status = contrib::rocm::SkipLayerNormSmallOp(¶ms_); + return status.IsOK(); + } + + private: + using ParamsT = contrib::rocm::SkipLayerNormParams; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type, threads_per_block, vec_size) \ + py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run) \ + .def("IsSupported", &name::IsSupported); + +#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \ + REGISTER_OP(name, type, threads_per_block, 1) \ + REGISTER_OP(name, type, threads_per_block, 2) \ + REGISTER_OP(name, type, threads_per_block, 4) \ + REGISTER_OP(name, type, threads_per_block, 8) \ + REGISTER_OP(name, type, threads_per_block, 16) + +#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \ + REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384) + +void InitSkipLayerNorm(py::module m) { + REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half); + REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.h b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.h new file mode 100644 index 0000000000..4dc72bc3da --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace py = pybind11; + +namespace onnxruntime { + +void InitSkipLayerNorm(py::module m); + +} diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py new file mode 100644 index 0000000000..5cbadcb73c --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/skip_layer_norm_test.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from itertools import product + +import kernel_explorer as ke +import numpy as np +import pytest + + +def get_bert_sizes(): + batch_sizes = [1, 8, 64, 128] + seq_lens = [64, 128, 256, 384, 512] + hidden_sizes = [768, 1024] + return product(batch_sizes, seq_lens, hidden_sizes) + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))), + "float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))), + } + return type_map[dtype] + + +def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon): + val = input_x + skip + bias + x_u = np.mean(val, axis=(2,)) + x_s = np.var(val, axis=(2,)) + output = val - x_u[..., None] + output = output / np.sqrt(x_s + epsilon)[..., None] + output = output * gamma + beta + return output + + +def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func): + np.random.seed(0) + input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + bias = np.random.rand(hidden_size).astype(dtype) + gamma = np.random.rand(hidden_size).astype(dtype) + beta = np.random.rand((hidden_size)).astype(dtype) + epsilon = 0.0005 + output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + + input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip) + bias_d = ke.DeviceArray(bias) + gamma_d = ke.DeviceArray(gamma) + beta_d = ke.DeviceArray(beta) + y_d = ke.DeviceArray(output_y) + my_func = getattr(ke, func) + my_op = my_func( + y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size + ) + if my_op.IsSupported(): + my_op.Run() + + y_d.UpdateHostNumpyArray() + + y_ref = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon) + np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05) + + +dtypes = ["float32", "float16"] + + +@pytest.mark.parametrize("bert_sizes", get_bert_sizes()) +@pytest.mark.parametrize("dtype", dtypes) +def test_skip_layer_norm(bert_sizes, dtype): + for func in dtype_to_funcs(dtype): + print(func) + run_skip_layer_norm(*bert_sizes, dtype, func) + + +def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func): + np.random.seed(0) + input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + gamma = np.random.rand(hidden_size).astype(dtype) + beta = np.random.rand(hidden_size).astype(dtype) + bias = np.random.rand(hidden_size).astype(dtype) + epsilon = 0.0005 + output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype) + + input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip) + gamma_d = ke.DeviceArray(gamma) + beta_d = ke.DeviceArray(beta) + bias_d = ke.DeviceArray(bias) + y_d = ke.DeviceArray(output_y) + my_func = getattr(ke, func) + my_op = my_func( + y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size + ) + if my_op.IsSupported(): + duration = my_op.Profile() + print( + dtype, + batch_size, + seq_len, + hidden_size, + my_func, + f"{duration * 1000:.2f} us", + f"{(input_x.size * 3 + bias.size * 3) * input_x.itemsize * 1e3 / duration / 1e9:.2f} GB/s", + ) + + +def profile(): + bert_sizes = get_bert_sizes() + for dtype in dtypes: + for bert_size in bert_sizes: + for func in dtype_to_funcs(dtype): + profile_skip_layer_norm_func(*bert_size, dtype, func) + print() + + +if __name__ == "__main__": + profile()