From 4e6ea730d633756e9e04df8968304d11a575dde4 Mon Sep 17 00:00:00 2001 From: Khalia Spear <134653808+khspear@users.noreply.github.com> Date: Mon, 7 Aug 2023 09:55:42 -0700 Subject: [PATCH] Broadcasting for SLN for CPU and CUDA (#16510) ### Description Enhanced SkipLayerNorm by implementing broadcasting for both CPU and CUDA ### Motivation and Context The input and skip tensors no longer have to be the same size which means that it can accept data where the skip shape can be the same size as the input shape, have a shape of {1, sequence_length, hidden_size}, or {sequence_length, hidden_size}. --------- Co-authored-by: Tianlei Wu --- .vscode/settings.json | 1 - docs/ContribOperators.md | 2 +- .../contrib_ops/cpu/skip_layer_norm.cc | 55 ++------ .../contrib_ops/cpu/skip_layer_norm_helper.h | 84 ++++++++++++ .../contrib_ops/cuda/bert/skip_layer_norm.cc | 67 +++------- .../cuda/bert/skip_layer_norm_impl.cu | 28 ++-- .../cuda/bert/skip_layer_norm_impl.h | 4 +- .../core/graph/contrib_ops/bert_defs.cc | 2 +- .../core/optimizer/skip_layer_norm_fusion.cc | 3 +- .../core/providers/cuda/nn/layer_norm_impl.cu | 22 +++- .../core/providers/cuda/nn/layer_norm_impl.h | 4 +- .../test/contrib_ops/skiplayernorm_op_test.cc | 121 +++++++++++++++++- 12 files changed, 276 insertions(+), 117 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h diff --git a/.vscode/settings.json b/.vscode/settings.json index fd28e2d7b3..b7a1292efb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,4 +40,3 @@ "-build/include_subdir", "-runtime/references" ] -} diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 77b84e2b3f..4cdadeb1ff 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4651,7 +4651,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
3D input tensor with shape (batch_size, sequence_length, hidden_size)
skip : T
-
3D skip tensor with shape (batch_size, sequence_length, hidden_size)
+
3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
gamma : T
1D input tensor with shape (hidden_size)
beta (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index b3668bcc0b..e86a12d9fb 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -6,6 +6,7 @@ #include "core/providers/common.h" #include "core/platform/threadpool.h" #include "skip_layer_norm.h" +#include "skip_layer_norm_helper.h" namespace onnxruntime { namespace contrib { @@ -45,51 +46,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } + ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size)); int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); @@ -105,13 +70,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData() : nullptr; + const auto& skip_size = skip->Shape().Size(); + concurrency::ThreadPool::TryBatchParallelFor( p_ctx->GetOperatorThreadPool(), static_cast(task_count), [&](ptrdiff_t task_idx) { auto offset = task_idx * hidden_size; const T* p_input = input_data + offset; - const T* p_skip = skip_data + offset; + const T* p_skip = skip_data + (offset % skip_size); T* p_output = output_data + offset; T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h new file mode 100644 index 0000000000..6271f82228 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace skip_layer_norm_helper { + +template +Status CheckInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check) { + const auto& input_dims_check = input->Shape().GetDims(); + const auto& skip_dims_check = skip->Shape().GetDims(); + size_t skip_dims_size_check = skip_dims_check.size(); + + if (skip_dims_size_check != 3 && skip_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have 3 or 2 dimensions, got ", skip_dims_size_check); + } + + if ((input->Shape() != skip->Shape()) && ((skip_dims_check[0] != 1 || skip_dims_size_check != 2) && input_dims_size_check != 3)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions"); + } + + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last two dimensions of skip needs to be same as input"); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + if (nullptr != beta) { + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of bias and input does not match"); + } + } + return Status::OK(); +} + +} // namespace skip_layer_norm_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 62f93b5e5b..0d8f887d98 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -5,6 +5,7 @@ #include "core/providers/cuda/nn/layer_norm_impl.h" #include "skip_layer_norm.h" #include "skip_layer_norm_impl.h" +#include "contrib_ops/cpu/skip_layer_norm_helper.h" namespace onnxruntime { namespace contrib { @@ -60,59 +61,23 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const // of the input and skip tensors Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } + const auto& skip_dims = skip->Shape().GetDims(); + size_t skip_dims_size = skip_dims.size(); int hidden_size = static_cast(input_dims[input_dims_size - 1]); - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } + ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size)); - if (!Simplified) { - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } + const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; + const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); if (strict_) { int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); @@ -131,7 +96,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + skip_broadcasted, + skip_size); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); @@ -153,7 +120,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const epsilon_, hidden_size, static_cast(element_count), - element_size); + element_size, + skip_broadcasted, + skip_size); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 0a815f7f49..1d0d97323d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -82,7 +82,7 @@ template __global__ void SkipLayerNormKernel( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output) { + const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -90,10 +90,13 @@ __global__ void SkipLayerNormKernel( // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); + for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; + const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; + const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i]; + const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); @@ -115,7 +118,7 @@ template __global__ void SkipLayerNormKernelSmall( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { + bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { const T rld = T(1.f / ld); const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld @@ -127,7 +130,11 @@ __global__ void SkipLayerNormKernelSmall( *input_val = *reinterpret_cast(&input[idx]); VecT* skip_val = reinterpret_cast(&skip_v); + if (skip_broadcasted){ + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); + }else{ *skip_val = *reinterpret_cast(&skip[idx]); + } if (hasBias) { VecT* bias_val = reinterpret_cast(&bias_v); @@ -170,8 +177,13 @@ template Status LaunchSkipLayerNormKernel( cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, const int ld, const int element_count, - size_t element_size) { + size_t element_size, const bool skip_broadcasted, const int skip_size) { // this must be true because n is the total size of the tensor + + if (element_count == 0) { + return Status::OK(); + } + assert(element_count % ld == 0); bool hasBias = (bias == nullptr) ? false : true; bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; @@ -187,10 +199,10 @@ Status LaunchSkipLayerNormKernel( #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ SkipLayerNormKernelSmall \ <<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput) + skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) #define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output) + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ if (flag_vec4) { \ @@ -219,7 +231,6 @@ Status LaunchSkipLayerNormKernel( #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL } - return CUDA_CALL(cudaGetLastError()); } @@ -229,7 +240,8 @@ Status LaunchSkipLayerNormKernel( const T* input, const T* skip, const T* gamma, \ const T* beta, const T* bias, float epsilon, \ const int ld, const int element_count, \ - size_t element_size); + size_t element_size, const bool skip_broadcasted, \ + const int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index da2894928c..4900e64be1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -21,7 +21,9 @@ Status LaunchSkipLayerNormKernel( float epsilon, // Layer normalization epsilon int hidden_size, // hidden size, it is the leading dimension (ld) int element_count, // number of elements in input tensor - size_t element_size); + size_t element_size, + const bool skip_broadcasted, // determines if broadcasting should be implemented + const int skip_size); // determines size of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d2ff7e5351..e5956a575d 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1097,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .SetDoc("Skip and Layer Normalization Fusion") .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultSkipLayerNormEpsilon) .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)", "T") .Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T") .Input(3, "beta", "1D skip tensor with shape (hidden_size", "T", OpSchema::Optional) .Input(4, "bias", "1D bias tensor with shape (hidden_size", "T", OpSchema::Optional) diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index ba2f1ea52f..cf70a7d821 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -12,7 +12,8 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; +static constexpr std::array supported_data_types{ + "tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; static bool IsSupportedDataType(const Node& node) { for (const auto& input_arg : node.InputDefs()) { diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 458dcf0b0f..4cc560a117 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -338,7 +338,9 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output) { + T* __restrict__ skip_input_bias_add_output, + const bool skip_broadcasted, + const int skip_size) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -358,11 +360,16 @@ __global__ void cuApplyLayerNorm( for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); + + if (bias != NULL) { curr += static_cast(bias[i]); } - if (skip_vals != NULL) { + if (skip_vals != NULL && skip_broadcasted) { + int skip_i = i % skip_size; + curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor + }else if (skip_vals != NULL){ curr += static_cast(skip_vals[i]); } @@ -411,7 +418,9 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output) { + T* skip_input_bias_add_output, + const bool skip_broadcasted, + const int skip_size) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -443,14 +452,17 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output); + skip, bias, skip_input_bias_add_output, + skip_broadcasted, + skip_size); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output); + const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ + const int skip_size); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index e3952eefae..d0d5db8ba3 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,7 +43,9 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr); + T* skip_input_bias_add_output = nullptr, + const bool skip_broadcasted = false, + const int skip_size = 0); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 67e279dc08..a41a1dd4ec 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -27,18 +27,31 @@ static void RunTest( bool no_beta = false, bool simplified = false, bool use_token_count = false, - bool strict = false) { + bool strict = false, + bool broadcast_skip = false, + bool no_batch_size = false) { // Input and output shapes // Input 0 - input: (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) - // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) + // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size) // Input 2 - gamma: (hidden_size) // Input 3 - beta : (hidden_size) // Output : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector skip_dims = input_dims; + if (use_token_count) { input_dims = {batch_size * sequence_length, hidden_size}; + skip_dims = input_dims; } - std::vector skip_dims = input_dims; + + if (broadcast_skip) { + skip_dims = {1, sequence_length, hidden_size}; + } + + if (no_batch_size) { + skip_dims = {sequence_length, hidden_size}; + } + std::vector gamma_dims = {hidden_size}; std::vector beta_dims = gamma_dims; std::vector bias_dims = gamma_dims; @@ -48,6 +61,8 @@ static void RunTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); + auto cpu_ep = DefaultCpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, input_data); @@ -77,7 +92,10 @@ static void RunTest( skip_input_bias_add_output_data); } - test.Run(); + if (cpu_ep != nullptr) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || rocm_ep != nullptr) { @@ -109,7 +127,6 @@ static void RunTest( ToFloat16(skip_input_bias_add_output_data)); } - std::vector> execution_providers; if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { @@ -718,6 +735,100 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { true, true); } + +TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector skip_data = { + 0.1f, -0.2f, 0.3f, 1.0f, + 0.5f, 0.1f, 0.4f, 1.6f}; + + std::vector gamma_data = { + 0.3f, 0.2f, 4.0f, 2.2f}; + + std::vector beta_data = { + 0.2f, 0.1f, 0.4f, 1.6f}; + + std::vector output_data = { + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438, + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438}; + + RunTest(input_data, + skip_data, + gamma_data, + beta_data, + std::vector(), + output_data, + {}, + epsilon_, + batch_size, + sequence_length, + hidden_size, + false, + false, + false, + false, + false, + false, + true); +} + +TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector skip_data = { + 0.1f, -0.2f, 0.3f, 1.0f, + 0.5f, 0.1f, 0.4f, 1.6f}; + + std::vector gamma_data = { + 0.3f, 0.2f, 4.0f, 2.2f}; + + std::vector beta_data = { + 0.2f, 0.1f, 0.4f, 1.6f}; + + std::vector output_data = { + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438, + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438}; + + RunTest(input_data, + skip_data, + gamma_data, + beta_data, + std::vector(), + output_data, + {}, + epsilon_, + batch_size, + sequence_length, + hidden_size, + false, + false, + false, + false, + false, + true, + false); +} #endif } // namespace test