mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add hipified SkipLayerNorm code for ROCmEP (#12107)
* First attempt for half2 vectorized memory access in SkipLayerNorm * Add some functions for debugging * Clean up the code * Clean up the code * Generalize the vectorized kernels with aligned_vector and remove cudaDeviceProp * Add a unit test for a larger input size * Fix some Lint C++ warnings * Use ILP = 4 for the vectorized kernels * Rewrite the vectorized kernel and templatize ComputeSkipLayerNorm * Use conditional operator for input_v * Refactor LaunchSkipLayerNormKernel and replace the original SkipLayerNormKernelSmall with the vectorized kernel * Clean some comments and rename the layernorm function * Use ComputeSkipLayerNorm to replace LaunchSkipLayerNormKernel * Resolve a Lint C++ warning * Fix SkipLayerNormBatch1_Float16_vec output data * Add hipified code of bert SkipLayerNorm for ROCmEP * Resolve some Lint C++ warnings * Resolve some Lint C++ warnings * Resolve some Lint C++ warnings * Resolve Python formatting issue
This commit is contained in:
parent
97b03fedff
commit
dbcf54aa41
6 changed files with 535 additions and 1 deletions
157
onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
Normal file
157
onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
#include "hip/hip_runtime.h"
|
||||
/*
|
||||
The implementation of this file is based on bert plugins in TensorRT demo:
|
||||
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
|
||||
|
||||
Copyright 2019 NVIDIA Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hipblas.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/shared_inc/rocm_call.h"
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T Rsqrt(const T& x);
|
||||
|
||||
template <>
|
||||
__device__ inline float Rsqrt(const float& x) {
|
||||
return rsqrtf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half Rsqrt(const half& x) {
|
||||
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
|
||||
return hrsqrt(x);
|
||||
#else
|
||||
return half(rsqrtf(static_cast<float>(x)));
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
|
||||
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
|
||||
return __hadd2(a, b);
|
||||
#else
|
||||
return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y));
|
||||
#endif
|
||||
}
|
||||
|
||||
struct KeyValuePairSum {
|
||||
__device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a,
|
||||
const hipcub::KeyValuePair<float, float>& b) {
|
||||
return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
|
||||
}
|
||||
|
||||
__device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a,
|
||||
const hipcub::KeyValuePair<half, half>& b) {
|
||||
const half2 a2 = __halves2half2(a.key, a.value);
|
||||
const half2 b2 = __halves2half2(b.key, b.value);
|
||||
const half2 res = AddHalf2(a2, b2);
|
||||
return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
|
||||
}
|
||||
|
||||
__device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a,
|
||||
const hipcub::KeyValuePair<half2, half2>& b) {
|
||||
return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int TPB>
|
||||
__device__ inline void LayerNorm(
|
||||
const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
|
||||
const T* gamma, const T epsilon, T* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
|
||||
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T mu; // mean
|
||||
__shared__ T rsigma; // 1 / std.dev.
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mu = sum_kv.key;
|
||||
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const T val = output[idx];
|
||||
const T g(gamma[i]);
|
||||
const T b = (nullptr == beta) ? (T)0 : beta[i];
|
||||
output[idx] = g * (val - mu) * rsigma + b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB, int ILP>
|
||||
__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair<T, T>& thread_data,
|
||||
const int ld, const int idx, const T* beta, const T* gamma,
|
||||
const T epsilon, T* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
// Small settings: the block covers the leading dimension TPB >= ld. The input
|
||||
// value is available in a register
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T mu; // mean
|
||||
__shared__ T rsigma; // 1 / std.dev.
|
||||
T beta_v[ILP], gamma_v[ILP], output_v[ILP];
|
||||
|
||||
if (beta != nullptr) {
|
||||
VecT* beta_val = reinterpret_cast<VecT*>(&beta_v);
|
||||
*beta_val = *reinterpret_cast<const VecT*>(&beta[threadIdx.x * ILP]);
|
||||
}
|
||||
VecT* gamma_val = reinterpret_cast<VecT*>(&gamma_v);
|
||||
*gamma_val = *reinterpret_cast<const VecT*>(&gamma[threadIdx.x * ILP]);
|
||||
|
||||
VecT* output_val = reinterpret_cast<VecT*>(&output_v);
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
const hipcub::KeyValuePair<T, T> sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mu = sum_kv.key;
|
||||
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
output_v[i] = (beta != nullptr) ? gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] :
|
||||
gamma_v[i] * (input_v[i] - mu) * rsigma;
|
||||
}
|
||||
*(reinterpret_cast<VecT*>(&output[idx])) = *reinterpret_cast<VecT*>(&output_v[0]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
122
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc
Normal file
122
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm.h"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
SkipLayerNormalization, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kRocmExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
SkipLayerNorm<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
template <typename T>
|
||||
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {
|
||||
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
|
||||
ORT_ENFORCE(epsilon_ >= 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
const Tensor* input = ctx->Input<Tensor>(0);
|
||||
const Tensor* skip = ctx->Input<Tensor>(1);
|
||||
const Tensor* gamma = ctx->Input<Tensor>(2);
|
||||
const Tensor* beta = ctx->Input<Tensor>(3);
|
||||
const Tensor* bias = ctx->Input<Tensor>(4);
|
||||
|
||||
Tensor* output = ctx->Output(0, input->Shape());
|
||||
|
||||
if (input->Shape() != skip->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have same shape as input");
|
||||
}
|
||||
|
||||
if (input->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& input_dims = input->Shape().GetDims();
|
||||
if (input_dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 3 dimensions, got ", input_dims.size());
|
||||
}
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != input_dims[2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of gamma and input does not match");
|
||||
}
|
||||
|
||||
if (nullptr != beta) {
|
||||
const auto& beta_dims = beta->Shape().GetDims();
|
||||
if (beta_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != input_dims[2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of beta and input does not match");
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr != bias) {
|
||||
const auto& bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"bias is expected to have 1 dimension, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != input_dims[2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of bias and input does not match");
|
||||
}
|
||||
}
|
||||
|
||||
int sequence_length = static_cast<int>(input_dims[1]);
|
||||
int hidden_size = static_cast<int>(input_dims[2]);
|
||||
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
|
||||
size_t element_size = sizeof(T);
|
||||
typedef typename ToHipType<T>::MappedType HipT;
|
||||
|
||||
if (!LaunchSkipLayerNormKernel<HipT>(
|
||||
Stream(),
|
||||
reinterpret_cast<HipT*>(output->template MutableData<T>()),
|
||||
reinterpret_cast<const HipT*>(input->template Data<T>()),
|
||||
reinterpret_cast<const HipT*>(skip->template Data<T>()),
|
||||
reinterpret_cast<const HipT*>(gamma->template Data<T>()),
|
||||
(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->template Data<T>()) : nullptr,
|
||||
(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->template Data<T>()) : nullptr,
|
||||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count),
|
||||
element_size)) {
|
||||
// Get last error to reset it to hipSuccess.
|
||||
HIP_CALL(hipGetLastError());
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
26
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h
Normal file
26
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/rocm/rocm_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
|
||||
template <typename T>
|
||||
class SkipLayerNorm final : public RocmKernel {
|
||||
public:
|
||||
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
float epsilon_;
|
||||
};
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
197
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu
Normal file
197
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
#include "hip/hip_runtime.h"
|
||||
/*
|
||||
The implementation of this file is based on skipLayerNorm plugin in TensorRT demo:
|
||||
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
|
||||
|
||||
Copyright 2019 NVIDIA Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// Modifications: Add SkipLayerNormKernelVec to
|
||||
// leverage vectorized load/write.
|
||||
// and templatize ComputeSkipLayerNorm for different
|
||||
// data types.
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include "contrib_ops/rocm/bert/layer_norm.cuh"
|
||||
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template<typename T>
|
||||
T maybe2half(float x);
|
||||
|
||||
template<>
|
||||
float maybe2half(float x) {
|
||||
return x;
|
||||
}
|
||||
|
||||
template<>
|
||||
half maybe2half(float x) {
|
||||
return __float2half_rn(x);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
|
||||
const T epsilon, T* output) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
// reduce x and x^2
|
||||
hipcub::KeyValuePair<T, T> thread_data(0, 0);
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
|
||||
const T rldval = reverse_ld * val;
|
||||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
|
||||
output[idx] = val;
|
||||
}
|
||||
|
||||
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
// Vectorized kernel
|
||||
template <typename T, unsigned TPB, int ILP>
|
||||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, bool hasBias) {
|
||||
const T rld = T(1.f / ld);
|
||||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
|
||||
|
||||
using VecT = aligned_vector<T, ILP>;
|
||||
__shared__ T mu; // mean
|
||||
__shared__ T rsigma; // 1 / std.dev.
|
||||
|
||||
T input_v[ILP], skip_v[ILP], bias_v[ILP], output_v[ILP];
|
||||
|
||||
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
|
||||
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
|
||||
|
||||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
|
||||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
|
||||
|
||||
if (hasBias) {
|
||||
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
|
||||
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
|
||||
}
|
||||
|
||||
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));
|
||||
|
||||
if (ILP * threadIdx.x < ld) {
|
||||
T rldval_sum = T(0.f);
|
||||
T rldvalsq_sum = T(0.f);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ILP; i++) {
|
||||
input_v[i] += hasBias ? skip_v[i] + bias_v[i]: skip_v[i];
|
||||
const T rldval = rld * input_v[i];
|
||||
rldval_sum += rldval;
|
||||
rldvalsq_sum += rldval * input_v[i];
|
||||
}
|
||||
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
|
||||
}
|
||||
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
|
||||
size_t element_size) {
|
||||
|
||||
// this must be true because n is the total size of the tensor
|
||||
assert(element_count % ld == 0);
|
||||
bool hasBias = (bias == nullptr) ? false : true;
|
||||
if (0 == (ld % 4)) {
|
||||
const int grid_size = element_count / ld;
|
||||
if (ld <= 32) {
|
||||
constexpr int block_size = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 64) {
|
||||
constexpr int block_size = 64 / 2;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 2>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 128) {
|
||||
constexpr int block_size = 128 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 384) {
|
||||
constexpr int block_size = 384 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 768) {
|
||||
constexpr int block_size = 768 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 1024) {
|
||||
constexpr int block_size = 1024 / 4;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
}
|
||||
} else {
|
||||
const int grid_size = element_count / ld;
|
||||
if (ld <= 32) {
|
||||
constexpr int block_size = 32;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 64) {
|
||||
constexpr int block_size = 64;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld <= 128) {
|
||||
constexpr int block_size = 128;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else if (ld == 384) {
|
||||
constexpr int block_size = 384;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
|
||||
} else {
|
||||
constexpr int block_size = 256;
|
||||
hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
|
||||
0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
|
||||
}
|
||||
}
|
||||
return HIP_CALL(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template bool LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
|
||||
const float* skip, const float* gamma, const float* beta,
|
||||
const float* bias, float epsilon, const int ld,
|
||||
const int element_count, size_t element_size);
|
||||
|
||||
template bool LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
|
||||
const half* skip, const half* gamma, const half* beta,
|
||||
const half* bias, float epsilon, const int ld,
|
||||
const int element_count, size_t element_size);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
||||
28
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h
Normal file
28
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
bool LaunchSkipLayerNormKernel(
|
||||
hipStream_t stream,
|
||||
T* output, // output tensor
|
||||
const T* input, // input tensor
|
||||
const T* skip, // skip tensor
|
||||
const T* gamma, // Layer normalization gamma tensor
|
||||
const T* beta, // Layer normalization beta tensor
|
||||
const T* bias, // Layer normalization beta tensor
|
||||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count, // number of elements in input tensor
|
||||
size_t element_size
|
||||
);
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
@ -29,7 +29,11 @@ contrib_ops_excluded_files = [
|
|||
"bert/fast_gelu_impl.h",
|
||||
"bert/fast_gelu.cc",
|
||||
"bert/fast_gelu.h",
|
||||
# 'bert/layer_norm.cuh',
|
||||
"bert/layer_norm.cuh",
|
||||
"bert/skip_layer_norm.cc",
|
||||
"bert/skip_layer_norm.h",
|
||||
"bert/skip_layer_norm_impl.cu",
|
||||
"bert/skip_layer_norm_impl.h",
|
||||
"bert/longformer_attention.cc",
|
||||
"bert/longformer_attention.h",
|
||||
"bert/longformer_attention_softmax.cu",
|
||||
|
|
|
|||
Loading…
Reference in a new issue