From 2d352056cf050785f672d32ee3df86e1228fb040 Mon Sep 17 00:00:00 2001 From: Weixing Zhang Date: Fri, 2 Apr 2021 09:03:30 -0700 Subject: [PATCH] Support SkipLayerNorm for ROCm EP (#7210) Co-authored-by: Weixing Zhang --- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 4 +- .../cuda/bert/skip_layer_norm_impl.cu | 4 +- .../contrib_ops/rocm/bert/layer_norm.cuh | 131 ++++++++++++++++++ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 4 +- .../test/contrib_ops/skiplayernorm_op_test.cc | 11 +- tools/ci_build/amd_hipify.py | 4 - 6 files changed, 146 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 2af85ca89f..da2e23b103 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -5,8 +5,8 @@ #include "core/providers/cuda/cuda_common.h" #include "core/framework/tensorprotoutils.h" #include "onnx/defs/tensor_proto_util.h" -#include "skip_layer_norm.h" -#include "skip_layer_norm_impl.h" +#include "contrib_ops/cuda/bert/skip_layer_norm.h" +#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index a7b6aabe52..4501186946 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -20,8 +20,8 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "layer_norm.cuh" -#include "skip_layer_norm_impl.h" +#include "contrib_ops/cuda/bert/layer_norm.cuh" +#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h" #include namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh new file mode 100644 index 0000000000..fa02a37aa6 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -0,0 +1,131 @@ +/* + The implementation of this file is based on bert plugins in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/shared_inc/rocm_call.h" +#include +#include +#include + +using namespace onnxruntime::rocm; +using namespace hipcub; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline T Rsqrt(const T& x); + +template <> +__device__ inline float Rsqrt(const float& x) { + return rsqrtf(x); +} + +template <> +__device__ inline half Rsqrt(const half& x) { + return hrsqrt(x); +} + +__device__ inline half2 AddHalf2(const half2 a, const half2 b) { + return __hadd2(a, b); +} + +struct KeyValuePairSum { + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, const hipcub::KeyValuePair& b) { + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = AddHalf2(a2, b2); + return hipcub::KeyValuePair(__low2half(res), __high2half(res)); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); + } +}; + +template +__device__ inline void LayerNorm( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const T* beta, + const T* gamma, const T epsilon, T* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = output[idx]; + const T g(gamma[i]); + const T b = (nullptr == beta) ? (T)0 : beta[i]; + output[idx] = g * (val - mu) * rsigma + b; + } +} + +template +__device__ inline void LayerNormSmall(const T val, const hipcub::KeyValuePair& thread_data, const int ld, const int idx, + const T* beta, const T* gamma, const T epsilon, T* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (threadIdx.x < ld) { + const T g(gamma[threadIdx.x]); + const T b = (nullptr == beta) ? (T)0 : beta[threadIdx.x]; + output[idx] = g * (val - mu) * rsigma + b; + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index a15d9eea03..4500d31939 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -128,8 +128,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index d571f43628..96d0f8c6be 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -36,6 +36,7 @@ static void RunTest( std::vector bias_dims = gamma_dims; std::vector output_dims = input_dims; + auto rocm_ep = DefaultRocmExecutionProvider(); if (!use_float16) { OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, input_data); @@ -53,7 +54,8 @@ static void RunTest( test.AddOutput("output", output_dims, output_data); test.Run(); - } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/)) { + } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || + rocm_ep != nullptr) { OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -71,7 +73,12 @@ static void RunTest( test.AddOutput("output", output_dims, ToFloat16(output_data)); std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + if (rocm_ep != nullptr) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } else { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 4b2bdc7329..06f09b17c0 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -32,10 +32,6 @@ contrib_ops_excluded_files = [ 'bert/longformer_attention_impl.h', 'bert/longformer_global_impl.cu', 'bert/longformer_global_impl.h', - 'bert/skip_layer_norm.cc', - 'bert/skip_layer_norm.h', - 'bert/skip_layer_norm_impl.cu', - 'bert/skip_layer_norm_impl.h', 'math/bias_softmax.cc', 'math/bias_softmax.h', 'math/bias_softmax_impl.cu',