[ADD] add skip layernorm to kernel explorer for ROCm EP (#12816)

**Description**: Describe your changes.
Related PR: https://github.com/microsoft/onnxruntime/pull/12803
https://github.com/microsoft/onnxruntime/pull/12817
https://github.com/microsoft/onnxruntime/pull/12821

Add skip layernorm to kernel explorer for profiling.

**Motivation and Context**
- Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here.
This commit is contained in:
PeixuanZuo 2022-09-20 17:17:01 +08:00 committed by GitHub
parent ffeba98a9d
commit 189aef2bea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 407 additions and 127 deletions

View file

@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm.h"
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
namespace onnxruntime {
@ -93,21 +94,19 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
size_t element_size = sizeof(T);
typedef typename ToHipType<T>::MappedType HipT;
return LaunchSkipLayerNormKernel<HipT>(
Stream(),
reinterpret_cast<HipT*>(output->MutableData<T>()),
reinterpret_cast<const HipT*>(input->Data<T>()),
reinterpret_cast<const HipT*>(skip->Data<T>()),
reinterpret_cast<const HipT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
static_cast<int>(element_count),
element_size);
Stream(),
reinterpret_cast<HipT*>(output->MutableData<T>()),
reinterpret_cast<const HipT*>(input->Data<T>()),
reinterpret_cast<const HipT*>(skip->Data<T>()),
reinterpret_cast<const HipT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
epsilon_,
hidden_size,
static_cast<int>(element_count));
}
} // namespace rocm

View file

@ -28,93 +28,20 @@ limitations under the License.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/layer_norm.cuh"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
T maybe2half(float x);
template <>
float maybe2half(float x) {
return x;
}
template <>
half maybe2half(float x) {
return __float2half_rn(x);
}
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
KeyValuePairSum pair_sum;
// reduce x and x^2
hipcub::KeyValuePair<T, T> thread_data(0, 0);
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
output[idx] = val;
}
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
}
// Vectorized kernel
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
if (ILP * threadIdx.x < ld) {
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
if (hasBias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}
template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
size_t element_size) {
const T* beta, const T* bias, float epsilon, const int ld, const int element_count) {
// this must be true because n is the total size of the tensor
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
@ -122,55 +49,55 @@ Status LaunchSkipLayerNormKernel(
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64 / 2;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 2>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 2><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 384) {
constexpr int block_size = 384 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 768) {
constexpr int block_size = 768 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 1024) {
constexpr int block_size = 1024 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
} else {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld == 384) {
constexpr int block_size = 384;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
}
return HIP_CALL(hipPeekAtLastError());
@ -179,12 +106,12 @@ Status LaunchSkipLayerNormKernel(
template Status LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
const int element_count, size_t element_size);
const int element_count);
template Status LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
const int element_count, size_t element_size);
const int element_count);
} // namespace rocm
} // namespace contrib

View file

@ -11,16 +11,15 @@ namespace rocm {
template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream,
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
size_t element_size
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count // number of elements in input tensor
);
} // namespace rocm

View file

@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/layer_norm.cuh"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
T maybe2half(float x);
template <>
float maybe2half(float x) {
return x;
}
template <>
half maybe2half(float x) {
return __float2half_rn(x);
}
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
KeyValuePairSum pair_sum;
// reduce x and x^2
hipcub::KeyValuePair<T, T> thread_data(0, 0);
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
output[idx] = val;
}
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
}
// Vectorized kernel
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
if (ILP * threadIdx.x < ld) {
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
if (hasBias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_runtime.h>
#include <memory>
#include <string>
#include <vector>
#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"
#include "contrib_ops/rocm/bert/tunable_op.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
struct SkipLayerNormParams : OpParams {
SkipLayerNormParams(hipStream_t stream, T* output, const T* input,
const T* skip, const T* gamma, const T* beta,
const T* bias, float epsilon, const int ld,
const int element_count)
: OpParams(stream), output(output), input(input), skip(skip), gamma(gamma), beta(beta), bias(bias),
epsilon(epsilon), ld(ld), element_count(element_count) {}
std::string Signature() const override {
std::string sig = std::to_string(ld) + "_" + std::to_string(element_count);
return sig;
}
T* output;
const T* input;
const T* skip;
const T* gamma;
const T* beta;
const T* bias;
const float epsilon;
const int ld;
const int element_count;
};
template <typename T, int ThreadsPerBlock, int VecSize>
Status SkipLayerNormSmallOp(const SkipLayerNormParams<T>* params) {
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
!((params->ld <= 1024 && params->ld % VecSize == 0 && params->ld == ThreadsPerBlock * VecSize)));
SkipLayerNormKernelSmall<T, ThreadsPerBlock, VecSize><<<dim3(CeilingDivision(params->element_count, params->ld)),
dim3(ThreadsPerBlock),
0, params->stream>>>(
params->ld, params->input, params->skip,
params->beta, params->gamma, params->bias, maybe2half<T>(params->epsilon), params->output,
(params->bias == nullptr) ? false : true);
return HIP_CALL(hipGetLastError());
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -7,6 +7,7 @@
#include "python/tools/kernel_explorer/kernels/vector_add.h"
#include "python/tools/kernel_explorer/kernels/fast_gelu.h"
#include "python/tools/kernel_explorer/kernels/gemm.h"
#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h"
namespace py = pybind11;
@ -19,6 +20,7 @@ PYBIND11_MODULE(_kernel_explorer, m) {
InitVectorAdd(m);
InitFastGelu(m);
InitGemm(m);
InitSkipLayerNorm(m);
}
} // namespace onnxruntime

View file

@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "python/tools/kernel_explorer/kernels/skip_layer_norm.h"
#include <hip/hip_fp16.h>
#include <pybind11/pybind11.h>
#include "contrib_ops/rocm/bert/skip_layer_norm_op.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
namespace py = pybind11;
namespace onnxruntime {
template <typename T, int ThreadsPerBlock, int VecSize>
class SkipLayerNormSmall : public IKernelExplorer {
public:
SkipLayerNormSmall(DeviceArray& output, DeviceArray& input, DeviceArray& skip,
DeviceArray& gamma, DeviceArray& beta, DeviceArray& bias,
float epsilon, int hidden_size, int element_count)
: params_(this->Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(input.ptr()),
static_cast<T*>(skip.ptr()), static_cast<T*>(gamma.ptr()), static_cast<T*>(beta.ptr()),
static_cast<T*>(bias.ptr()), epsilon, hidden_size, element_count) {}
void Run() override {
ORT_THROW_IF_ERROR((contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(&params_)));
}
bool IsSupported() {
Status status = contrib::rocm::SkipLayerNormSmallOp<T, ThreadsPerBlock, VecSize>(&params_);
return status.IsOK();
}
private:
using ParamsT = contrib::rocm::SkipLayerNormParams<T>;
ParamsT params_{};
};
#define REGISTER_OP(name, type, threads_per_block, vec_size) \
py::class_<name<type, threads_per_block, vec_size>>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, \
DeviceArray&, DeviceArray&, \
float, int, int>()) \
.def("SetRepeats", &name<type, threads_per_block, vec_size>::SetRepeats) \
.def("Profile", &name<type, threads_per_block, vec_size>::Profile) \
.def("Run", &name<type, threads_per_block, vec_size>::Run) \
.def("IsSupported", &name<type, threads_per_block, vec_size>::IsSupported);
#define REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, threads_per_block) \
REGISTER_OP(name, type, threads_per_block, 1) \
REGISTER_OP(name, type, threads_per_block, 2) \
REGISTER_OP(name, type, threads_per_block, 4) \
REGISTER_OP(name, type, threads_per_block, 8) \
REGISTER_OP(name, type, threads_per_block, 16)
#define REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(name, type) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 64) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 128) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 192) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 256) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 320) \
REGISTER_OP_FOR_ALL_VEC_SIZE(name, type, 384)
void InitSkipLayerNorm(py::module m) {
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, half);
REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(SkipLayerNormSmall, float);
}
} // namespace onnxruntime

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
namespace onnxruntime {
void InitSkipLayerNorm(py::module m);
}

View file

@ -0,0 +1,121 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from itertools import product
import kernel_explorer as ke
import numpy as np
import pytest
def get_bert_sizes():
batch_sizes = [1, 8, 64, 128]
seq_lens = [64, 128, 256, 384, 512]
hidden_sizes = [768, 1024]
return product(batch_sizes, seq_lens, hidden_sizes)
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: "SkipLayerNormSmall_half" in x, dir(ke))),
"float32": list(filter(lambda x: "SkipLayerNormSmall_float" in x, dir(ke))),
}
return type_map[dtype]
def skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon):
val = input_x + skip + bias
x_u = np.mean(val, axis=(2,))
x_s = np.var(val, axis=(2,))
output = val - x_u[..., None]
output = output / np.sqrt(x_s + epsilon)[..., None]
output = output * gamma + beta
return output
def run_skip_layer_norm(batch_size: int, seq_len: int, hidden_size: int, dtype: str, func):
np.random.seed(0)
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
bias = np.random.rand(hidden_size).astype(dtype)
gamma = np.random.rand(hidden_size).astype(dtype)
beta = np.random.rand((hidden_size)).astype(dtype)
epsilon = 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
bias_d = ke.DeviceArray(bias)
gamma_d = ke.DeviceArray(gamma)
beta_d = ke.DeviceArray(beta)
y_d = ke.DeviceArray(output_y)
my_func = getattr(ke, func)
my_op = my_func(
y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size
)
if my_op.IsSupported():
my_op.Run()
y_d.UpdateHostNumpyArray()
y_ref = skip_layer_norm(input_x, skip, bias, gamma, beta, epsilon)
np.testing.assert_almost_equal(y_ref, output_y, decimal=1e-05)
dtypes = ["float32", "float16"]
@pytest.mark.parametrize("bert_sizes", get_bert_sizes())
@pytest.mark.parametrize("dtype", dtypes)
def test_skip_layer_norm(bert_sizes, dtype):
for func in dtype_to_funcs(dtype):
print(func)
run_skip_layer_norm(*bert_sizes, dtype, func)
def profile_skip_layer_norm_func(batch_size, seq_len, hidden_size, dtype, func):
np.random.seed(0)
input_x = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
skip = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
gamma = np.random.rand(hidden_size).astype(dtype)
beta = np.random.rand(hidden_size).astype(dtype)
bias = np.random.rand(hidden_size).astype(dtype)
epsilon = 0.0005
output_y = np.random.rand(batch_size, seq_len, hidden_size).astype(dtype)
input_d = ke.DeviceArray(input_x)
skip_d = ke.DeviceArray(skip)
gamma_d = ke.DeviceArray(gamma)
beta_d = ke.DeviceArray(beta)
bias_d = ke.DeviceArray(bias)
y_d = ke.DeviceArray(output_y)
my_func = getattr(ke, func)
my_op = my_func(
y_d, input_d, skip_d, gamma_d, beta_d, bias_d, epsilon, hidden_size, batch_size * seq_len * hidden_size
)
if my_op.IsSupported():
duration = my_op.Profile()
print(
dtype,
batch_size,
seq_len,
hidden_size,
my_func,
f"{duration * 1000:.2f} us",
f"{(input_x.size * 3 + bias.size * 3) * input_x.itemsize * 1e3 / duration / 1e9:.2f} GB/s",
)
def profile():
bert_sizes = get_bert_sizes()
for dtype in dtypes:
for bert_size in bert_sizes:
for func in dtype_to_funcs(dtype):
profile_skip_layer_norm_func(*bert_size, dtype, func)
print()
if __name__ == "__main__":
profile()