Add hipified SkipLayerNorm code for ROCmEP (#12107)

* First attempt for half2 vectorized memory access in SkipLayerNorm

* Add some functions for debugging

* Clean up the code

* Clean up the code

* Generalize the vectorized kernels with aligned_vector and remove cudaDeviceProp

* Add a unit test for a larger input size

* Fix some Lint C++ warnings

* Use ILP = 4 for the vectorized kernels

* Rewrite the vectorized kernel and templatize ComputeSkipLayerNorm

* Use conditional operator for input_v

* Refactor LaunchSkipLayerNormKernel and replace the original SkipLayerNormKernelSmall with the vectorized kernel

* Clean some comments and rename the layernorm function

* Use ComputeSkipLayerNorm to replace LaunchSkipLayerNormKernel

* Resolve a Lint C++ warning

* Fix SkipLayerNormBatch1_Float16_vec output data

* Add hipified code of bert SkipLayerNorm for ROCmEP

* Resolve some Lint C++ warnings

* Resolve some Lint C++ warnings

* Resolve some Lint C++ warnings

* Resolve Python formatting issue
This commit is contained in:
Hubert Lu 2022-07-06 22:13:11 -07:00 committed by GitHub
parent 97b03fedff
commit dbcf54aa41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 535 additions and 1 deletions

View file

@ -0,0 +1,157 @@
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <hip/hip_fp16.h>
#include <hipblas.h>
#include <hipcub/hipcub.hpp>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/cu_inc/common.cuh"
#include "core/providers/rocm/shared_inc/rocm_call.h"
using namespace onnxruntime::rocm;
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
__device__ inline T Rsqrt(const T& x);
template <>
__device__ inline float Rsqrt(const float& x) {
return rsqrtf(x);
}
template <>
__device__ inline half Rsqrt(const half& x) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return hrsqrt(x);
#else
return half(rsqrtf(static_cast<float>(x)));
#endif
}
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hadd2(a, b);
#else
return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y));
#endif
}
struct KeyValuePairSum {
__device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a,
const hipcub::KeyValuePair<float, float>& b) {
return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
}
__device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a,
const hipcub::KeyValuePair<half, half>& b) {
const half2 a2 = __halves2half2(a.key, a.value);
const half2 b2 = __halves2half2(b.key, b.value);
const half2 res = AddHalf2(a2, b2);
return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
}
__device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a,
const hipcub::KeyValuePair<half2, half2>& b) {
return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}
};
template <typename T, int TPB>
__device__ inline void LayerNorm(
const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
const T* gamma, const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
const T b = (nullptr == beta) ? (T)0 : beta[i];
output[idx] = g * (val - mu) * rsigma + b;
}
}
template <typename T, int TPB, int ILP>
__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<T, T>& thread_data,
const int ld, const int idx, const T* beta, const T* gamma,
const T epsilon, T* output) {
// Assuming thread_data is already divided by ld
// Small settings: the block covers the leading dimension TPB >= ld. The input
// value is available in a register
using VecT = aligned_vector<T, ILP>;
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
T beta_v[ILP], gamma_v[ILP], output_v[ILP];
if (beta != nullptr) {
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
}
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
VecT* output_val = reinterpret_cast<VecT*>(&output_v);
KeyValuePairSum pair_sum;
const hipcub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
}
__syncthreads();
if (ILP * threadIdx.x < ld) {
#pragma unroll
for (int i = 0; i < ILP; i++) {
output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] :
gamma_v[i] * (input_v[i] - mu) * rsigma;
}
*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&output_v[0]);
}
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,122 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SkipLayerNormalization, \
kMSDomain, \
1, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
template <typename T>
Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* input = ctx->Input<Tensor>(0);
const Tensor* skip = ctx->Input<Tensor>(1);
const Tensor* gamma = ctx->Input<Tensor>(2);
const Tensor* beta = ctx->Input<Tensor>(3);
const Tensor* bias = ctx->Input<Tensor>(4);
Tensor* output = ctx->Output(0, input->Shape());
if (input->Shape() != skip->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
}
if (input->Shape().Size() == 0) {
return Status::OK();
}
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input is expected to have 3 dimensions, got ", input_dims.size());
}
const auto& gamma_dims = gamma->Shape().GetDims();
if (gamma_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
}
if (gamma_dims[0] != input_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Last dimension of gamma and input does not match");
}
if (nullptr != beta) {
const auto& beta_dims = beta->Shape().GetDims();
if (beta_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"beta is expected to have 1 dimension, got ", beta_dims.size());
}
if (beta_dims[0] != input_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Last dimension of beta and input does not match");
}
}
if (nullptr != bias) {
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 1 dimension, got ", bias_dims.size());
}
if (bias_dims[0] != input_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Last dimension of bias and input does not match");
}
}
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
size_t element_size = sizeof(T);
typedef typename ToHipType<T>::MappedType HipT;
if (!LaunchSkipLayerNormKernel<HipT>(
Stream(),
reinterpret_cast<HipT*>(output->template MutableData<T>()),
reinterpret_cast<const HipT*>(input->template Data<T>()),
reinterpret_cast<const HipT*>(skip->template Data<T>()),
reinterpret_cast<const HipT*>(gamma->template Data<T>()),
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->template Data<T>()) : nullptr,
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->template Data<T>()) : nullptr,
epsilon_,
hidden_size,
static_cast<int>(element_count),
element_size)) {
// Get last error to reset it to hipSuccess.
HIP_CALL(hipGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}
return Status::OK();
}
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/providers/rocm/rocm_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
using namespace onnxruntime::rocm;
template <typename T>
class SkipLayerNorm final : public RocmKernel {
public:
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;
private:
float epsilon_;
};
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,197 @@
#include "hip/hip_runtime.h"
/*
The implementation of this file is based on skipLayerNorm plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// Modifications: Add SkipLayerNormKernelVec to
// leverage vectorized load/write.
// and templatize ComputeSkipLayerNorm for different
// data types.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/layer_norm.cuh"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace rocm {
template<typename T>
T maybe2half(float x);
template<>
float maybe2half(float x) {
return x;
}
template<>
half maybe2half(float x) {
return __float2half_rn(x);
}
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
KeyValuePairSum pair_sum;
// reduce x and x^2
hipcub::KeyValuePair<T, T> thread_data(0, 0);
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
output[idx] = val;
}
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
}
// Vectorized kernel
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
using VecT = aligned_vector<T, ILP>;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
T input_v[ILP], skip_v[ILP], bias_v[ILP], output_v[ILP];
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
if (hasBias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
if (ILP * threadIdx.x < ld) {
T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i]: skip_v[i];
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}
template <typename T>
bool LaunchSkipLayerNormKernel(
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
size_t element_size) {
// this must be true because n is the total size of the tensor
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
if (0 == (ld % 4)) {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64 / 2;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 2>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 384) {
constexpr int block_size = 384 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 768) {
constexpr int block_size = 768 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 1024) {
constexpr int block_size = 1024 / 4;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
} else {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld == 384) {
constexpr int block_size = 384;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
}
return HIP_CALL(hipPeekAtLastError());
}
template bool LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
const int element_count, size_t element_size);
template bool LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
const int element_count, size_t element_size);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace rocm {
template <typename T>
bool LaunchSkipLayerNormKernel(
hipStream_t stream,
T* output, // output tensor
const T* input, // input tensor
const T* skip, // skip tensor
const T* gamma, // Layer normalization gamma tensor
const T* beta, // Layer normalization beta tensor
const T* bias, // Layer normalization beta tensor
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
size_t element_size
);
} // namespace rocm
} // namespace contrib
} // namespace onnxruntime

View file

@ -29,7 +29,11 @@ contrib_ops_excluded_files = [
"bert/fast_gelu_impl.h",
"bert/fast_gelu.cc",
"bert/fast_gelu.h",
# 'bert/layer_norm.cuh',
"bert/layer_norm.cuh",
"bert/skip_layer_norm.cc",
"bert/skip_layer_norm.h",
"bert/skip_layer_norm_impl.cu",
"bert/skip_layer_norm_impl.h",
"bert/longformer_attention.cc",
"bert/longformer_attention.h",
"bert/longformer_attention_softmax.cu",