diff --git a/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh new file mode 100644 index 0000000000..5ecf6815fb --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/layer_norm.cuh @@ -0,0 +1,157 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on bert plugins in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/cu_inc/common.cuh" +#include "core/providers/rocm/shared_inc/rocm_call.h" + +using namespace onnxruntime::rocm; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +__device__ inline T Rsqrt(const T& x); + +template <> +__device__ inline float Rsqrt(const float& x) { + return rsqrtf(x); +} + +template <> +__device__ inline half Rsqrt(const half& x) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return hrsqrt(x); +#else + return half(rsqrtf(static_cast(x))); +#endif +} + +__device__ inline half2 AddHalf2(const half2 a, const half2 b) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return __hadd2(a, b); +#else + return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); +#endif +} + +struct KeyValuePairSum { + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(a.key + b.key, a.value + b.value); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = AddHalf2(a2, b2); + return hipcub::KeyValuePair(__low2half(res), __high2half(res)); + } + + __device__ inline hipcub::KeyValuePair operator()(const hipcub::KeyValuePair& a, + const hipcub::KeyValuePair& b) { + return hipcub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); + } +}; + +template +__device__ inline void LayerNorm( + const hipcub::KeyValuePair& thread_data, const int ld, const int offset, const T* beta, + const T* gamma, const T epsilon, T* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ 
T rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = output[idx]; + const T g(gamma[i]); + const T b = (nullptr == beta) ? (T)0 : beta[i]; + output[idx] = g * (val - mu) * rsigma + b; + } +} + +template +__device__ inline void LayerNormSmall(const T* input_v, const hipcub::KeyValuePair& thread_data, + const int ld, const int idx, const T* beta, const T* gamma, + const T epsilon, T* output) { + // Assuming thread_data is already divided by ld (key = sum of x / ld, value = sum of x * x / ld) + // Small settings: one block covers the leading dimension (TPB * ILP >= ld), ILP elements per thread. The input + // value is available in a register + using VecT = aligned_vector; + using BlockReduce = hipcub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. + T beta_v[ILP], gamma_v[ILP], output_v[ILP]; + + if (beta != nullptr) { + VecT* beta_val = reinterpret_cast(&beta_v); + *beta_val = *reinterpret_cast(&beta[threadIdx.x * ILP]); + } + VecT* gamma_val = reinterpret_cast(&gamma_v); + *gamma_val = *reinterpret_cast(&gamma[threadIdx.x * ILP]); + + VecT* output_val = reinterpret_cast(&output_v); + + KeyValuePairSum pair_sum; + const hipcub::KeyValuePair sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu + epsilon); + } + __syncthreads(); + + if (ILP * threadIdx.x < ld) { + #pragma unroll + for (int i = 0; i < ILP; i++) { + output_v[i] = (beta != nullptr) ? 
gamma_v[i] * (input_v[i] - mu) * rsigma + beta_v[i] : + gamma_v[i] * (input_v[i] - mu) * rsigma; + } + *(reinterpret_cast(&output[idx])) = *reinterpret_cast(&output_v[0]); + } +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc new file mode 100644 index 0000000000..b71521f0c1 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/rocm/rocm_common.h" +#include "contrib_ops/rocm/bert/skip_layer_norm.h" +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); + ORT_ENFORCE(epsilon_ >= 0); +} + +template +Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input = ctx->Input(0); + const Tensor* skip = ctx->Input(1); + const Tensor* gamma = ctx->Input(2); + const Tensor* beta = ctx->Input(3); + const Tensor* bias = ctx->Input(4); + + Tensor* output = ctx->Output(0, input->Shape()); + + if (input->Shape() != skip->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input"); + } + + if (input->Shape().Size() == 0) { + return Status::OK(); + } + + const auto& input_dims = 
input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 dimensions, got ", input_dims.size()); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + if (nullptr != beta) { + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of bias and input does not match"); + } + } + + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int64_t element_count = input_dims[0] * sequence_length * hidden_size; + size_t element_size = sizeof(T); + typedef typename ToHipType::MappedType HipT; + + if (!LaunchSkipLayerNormKernel( + Stream(), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + reinterpret_cast(skip->template Data()), + reinterpret_cast(gamma->template Data()), + (beta != nullptr) ? reinterpret_cast(beta->template Data()) : nullptr, + (bias != nullptr) ? 
reinterpret_cast(bias->template Data()) : nullptr, + epsilon_, + hidden_size, + static_cast(element_count), + element_size)) { + // Get last error to reset it to hipSuccess. + HIP_CALL(hipGetLastError()); + return Status(common::ONNXRUNTIME, common::FAIL); + } + + return Status::OK(); +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h new file mode 100644 index 0000000000..07d7037227 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class SkipLayerNorm final : public RocmKernel { + public: + SkipLayerNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + float epsilon_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu new file mode 100644 index 0000000000..4b99e0e527 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu @@ -0,0 +1,197 @@ +#include "hip/hip_runtime.h" +/* + The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Modifications: Add SkipLayerNormKernelVec to +// leverage vectorized load/write. +// and templatize ComputeSkipLayerNorm for different +// data types. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +// Licensed under the MIT License. + +#include +#include "contrib_ops/rocm/bert/layer_norm.cuh" +#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +T maybe2half(float x); + +template<> +float maybe2half(float x) { + return x; +} + +template<> +half maybe2half(float x) { + return __float2half_rn(x); +} + +template +__global__ void SkipLayerNormKernel( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, + const T epsilon, T* output) { + const T reverse_ld = T(1.f / ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + hipcub::KeyValuePair thread_data(0, 0); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = (bias == nullptr) ? 
input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; + const T rldval = reverse_ld * val; + thread_data = pair_sum(thread_data, hipcub::KeyValuePair(rldval, rldval * val)); + output[idx] = val; + } + + LayerNorm(thread_data, ld, offset, beta, gamma, epsilon, output); +} + +// Vectorized kernel: each thread handles ILP consecutive elements of one row. NOTE(review): the vector loads of input/skip/bias below are issued before the ILP * threadIdx.x < ld guard; confirm they cannot read past the row when TPB * ILP > ld +template +__global__ void SkipLayerNormKernelSmall( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, + const T* bias, const T epsilon, T* output, bool hasBias) { + const T rld = T(1.f / ld); + const int idx = blockIdx.x * ld + threadIdx.x * ILP; // one block per row; the launcher uses grid_size = element_count / ld + + using VecT = aligned_vector; + __shared__ T mu; // mean (not referenced in this kernel) + __shared__ T rsigma; // 1 / std.dev. (not referenced in this kernel) + + T input_v[ILP], skip_v[ILP], bias_v[ILP], output_v[ILP]; + + VecT* input_val = reinterpret_cast(&input_v); + *input_val = *reinterpret_cast(&input[idx]); + + VecT* skip_val = reinterpret_cast(&skip_v); + *skip_val = *reinterpret_cast(&skip[idx]); + + if (hasBias) { + VecT* bias_val = reinterpret_cast(&bias_v); + *bias_val = *reinterpret_cast(&bias[threadIdx.x * ILP]); + } + + hipcub::KeyValuePair thread_data(T(0.f), T(0.f)); + + if (ILP * threadIdx.x < ld) { + T rldval_sum = T(0.f); + T rldvalsq_sum = T(0.f); + #pragma unroll + for (int i = 0; i < ILP; i++) { + input_v[i] += hasBias ? skip_v[i] + bias_v[i]: skip_v[i]; + const T rldval = rld * input_v[i]; + rldval_sum += rldval; + rldvalsq_sum += rldval * input_v[i]; + } + thread_data = hipcub::KeyValuePair(rldval_sum, rldvalsq_sum); + } + LayerNormSmall(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); +} + +template +bool LaunchSkipLayerNormKernel( + hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma, + const T* beta, const T* bias, float epsilon, const int ld, const int element_count, + size_t element_size) { + + // this must be true because element_count is the total size of the tensor and ld is its last dimension + assert(element_count % ld == 0); + bool hasBias = (bias == nullptr) ? 
false : true; + if (0 == (ld % 4)) { + const int grid_size = element_count / ld; + if (ld <= 32) { + constexpr int block_size = 32; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 64) { + constexpr int block_size = 64 / 2; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 128) { + constexpr int block_size = 128 / 4; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 384) { + constexpr int block_size = 384 / 4; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 768) { + constexpr int block_size = 768 / 4; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 1024) { + constexpr int block_size = 1024 / 4; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else { + constexpr int block_size = 256; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); + } + } else { + const int grid_size = element_count / ld; + if (ld <= 32) { + constexpr int block_size = 32; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } 
else if (ld <= 64) { + constexpr int block_size = 64; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld <= 128) { + constexpr int block_size = 128; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else if (ld == 384) { // NOTE(review): unreachable here, this chain only runs when ld % 4 != 0 and 384 % 4 == 0 + constexpr int block_size = 384; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, hasBias); + } else { + constexpr int block_size = 256; + hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel), grid_size, block_size, + 0, stream, ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output); + } + } + return HIP_CALL(hipPeekAtLastError()); // peek leaves the error set; the caller in skip_layer_norm.cc clears it with hipGetLastError() +} + +template bool LaunchSkipLayerNormKernel(hipStream_t stream, float* output, const float* input, + const float* skip, const float* gamma, const float* beta, + const float* bias, float epsilon, const int ld, + const int element_count, size_t element_size); + +template bool LaunchSkipLayerNormKernel(hipStream_t stream, half* output, const half* input, + const half* skip, const half* gamma, const half* beta, + const half* bias, float epsilon, const int ld, + const int element_count, size_t element_size); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + + diff --git a/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h new file mode 100644 index 0000000000..08ab0a5f06 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +template +bool LaunchSkipLayerNormKernel( + hipStream_t stream, + T* output, // output tensor + const T* input, // input tensor + const T* skip, // skip tensor + const T* gamma, // Layer normalization gamma tensor + const T* beta, // Layer normalization beta tensor (may be nullptr) + const T* bias, // Optional bias tensor added to input and skip before normalization (may be nullptr) + float epsilon, // Layer normalization epsilon + int hidden_size, // hidden size, it is the leading dimension (ld) + int element_count, // number of elements in input tensor + size_t element_size // size of one element in bytes; not read by the launcher in skip_layer_norm_impl.cu +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime + diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index bbf0dcc242..650747a155 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -29,7 +29,11 @@ contrib_ops_excluded_files = [ "bert/fast_gelu_impl.h", "bert/fast_gelu.cc", "bert/fast_gelu.h", - # 'bert/layer_norm.cuh', + "bert/layer_norm.cuh", + "bert/skip_layer_norm.cc", + "bert/skip_layer_norm.h", + "bert/skip_layer_norm_impl.cu", + "bert/skip_layer_norm_impl.h", "bert/longformer_attention.cc", "bert/longformer_attention.h", "bert/longformer_attention_softmax.cu",