mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
Support SkipLayerNorm for ROCm EP (#7210)
Co-authored-by: Weixing Zhang <wezhan@microsoft.com>
This commit is contained in:
parent
a3f17c8b0d
commit
2d352056cf
6 changed files with 146 additions and 12 deletions
|
|
@ -5,8 +5,8 @@
|
|||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "onnx/defs/tensor_proto_util.h"
|
||||
#include "skip_layer_norm.h"
|
||||
#include "skip_layer_norm_impl.h"
|
||||
#include "contrib_ops/cuda/bert/skip_layer_norm.h"
|
||||
#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ limitations under the License.
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "layer_norm.cuh"
|
||||
#include "skip_layer_norm_impl.h"
|
||||
#include "contrib_ops/cuda/bert/layer_norm.cuh"
|
||||
#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
|
|||
131
onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
Normal file
131
onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
The implementation of this file is based on bert plugins in TensorRT demo:
|
||||
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
|
||||
|
||||
Copyright 2019 NVIDIA Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/cu_inc/common.cuh"
|
||||
#include "core/providers/rocm/shared_inc/rocm_call.h"
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hipblas.h>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
|
||||
using namespace onnxruntime::rocm;
|
||||
using namespace hipcub;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T Rsqrt(const T& x);
|
||||
|
||||
template <>
|
||||
__device__ inline float Rsqrt(const float& x) {
|
||||
return rsqrtf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline half Rsqrt(const half& x) {
|
||||
return hrsqrt(x);
|
||||
}
|
||||
|
||||
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
struct KeyValuePairSum {
|
||||
__device__ inline hipcub::KeyValuePair<float, float> operator()(const hipcub::KeyValuePair<float, float>& a, const hipcub::KeyValuePair<float, float>& b) {
|
||||
return hipcub::KeyValuePair<float, float>(a.key + b.key, a.value + b.value);
|
||||
}
|
||||
|
||||
__device__ inline hipcub::KeyValuePair<half, half> operator()(const hipcub::KeyValuePair<half, half>& a, const hipcub::KeyValuePair<half, half>& b) {
|
||||
const half2 a2 = __halves2half2(a.key, a.value);
|
||||
const half2 b2 = __halves2half2(b.key, b.value);
|
||||
const half2 res = AddHalf2(a2, b2);
|
||||
return hipcub::KeyValuePair<half, half>(__low2half(res), __high2half(res));
|
||||
}
|
||||
|
||||
__device__ inline hipcub::KeyValuePair<half2, half2> operator()(const hipcub::KeyValuePair<half2, half2>& a, const hipcub::KeyValuePair<half2, half2>& b) {
|
||||
return hipcub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int TPB>
|
||||
__device__ inline void LayerNorm(
|
||||
const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta,
|
||||
const T* gamma, const T epsilon, T* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
|
||||
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T mu; // mean
|
||||
__shared__ T rsigma; // 1 / std.dev.
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mu = sum_kv.key;
|
||||
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
const T val = output[idx];
|
||||
const T g(gamma[i]);
|
||||
const T b = (nullptr == beta) ? (T)0 : beta[i];
|
||||
output[idx] = g * (val - mu) * rsigma + b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int TPB>
|
||||
__device__ inline void LayerNormSmall(const T val, const hipcub::KeyValuePair<T, T>& thread_data, const int ld, const int idx,
|
||||
const T* beta, const T* gamma, const T epsilon, T* output) {
|
||||
// Assuming thread_data is already divided by ld
|
||||
// Small settings: the block covers the leading dimension TPB >= ld. The input
|
||||
// value is available in a register
|
||||
|
||||
using BlockReduce = hipcub::BlockReduce<hipcub::KeyValuePair<T, T>, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ T mu; // mean
|
||||
__shared__ T rsigma; // 1 / std.dev.
|
||||
|
||||
KeyValuePairSum pair_sum;
|
||||
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
mu = sum_kv.key;
|
||||
rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x < ld) {
|
||||
const T g(gamma[threadIdx.x]);
|
||||
const T b = (nullptr == beta) ? (T)0 : beta[threadIdx.x];
|
||||
output[idx] = g * (val - mu) * rsigma + b;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rocm
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -128,8 +128,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ static void RunTest(
|
|||
std::vector<int64_t> bias_dims = gamma_dims;
|
||||
std::vector<int64_t> output_dims = input_dims;
|
||||
|
||||
auto rocm_ep = DefaultRocmExecutionProvider();
|
||||
if (!use_float16) {
|
||||
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", input_dims, input_data);
|
||||
|
|
@ -53,7 +54,8 @@ static void RunTest(
|
|||
|
||||
test.AddOutput<float>("output", output_dims, output_data);
|
||||
test.Run();
|
||||
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/)) {
|
||||
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
|
||||
rocm_ep != nullptr) {
|
||||
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
|
||||
test.AddInput<MLFloat16>("skip", skip_dims, ToFloat16(skip_data));
|
||||
|
|
@ -71,7 +73,12 @@ static void RunTest(
|
|||
test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
if (rocm_ep != nullptr) {
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
} else {
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
}
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,10 +32,6 @@ contrib_ops_excluded_files = [
|
|||
'bert/longformer_attention_impl.h',
|
||||
'bert/longformer_global_impl.cu',
|
||||
'bert/longformer_global_impl.h',
|
||||
'bert/skip_layer_norm.cc',
|
||||
'bert/skip_layer_norm.h',
|
||||
'bert/skip_layer_norm_impl.cu',
|
||||
'bert/skip_layer_norm_impl.h',
|
||||
'math/bias_softmax.cc',
|
||||
'math/bias_softmax.h',
|
||||
'math/bias_softmax_impl.cu',
|
||||
|
|
|
|||
Loading…
Reference in a new issue