diff --git a/.vscode/settings.json b/.vscode/settings.json index fd28e2d7b3..b7a1292efb 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,4 +40,3 @@ "-build/include_subdir", "-runtime/references" ] -} diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 77b84e2b3f..4cdadeb1ff 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -4651,7 +4651,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
3D input tensor with shape (batch_size, sequence_length, hidden_size)
skip : T
-
3D skip tensor with shape (batch_size, sequence_length, hidden_size)
+
3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
gamma : T
1D input tensor with shape (hidden_size)
beta (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index b3668bcc0b..e86a12d9fb 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -6,6 +6,7 @@ #include "core/providers/common.h" #include "core/platform/threadpool.h" #include "skip_layer_norm.h" +#include "skip_layer_norm_helper.h" namespace onnxruntime { namespace contrib { @@ -45,51 +46,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } - int hidden_size = static_cast(input_dims[input_dims_size - 1]); - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } - - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } + ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size)); int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1); @@ -105,13 +70,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData() : nullptr; + const auto& skip_size = skip->Shape().Size(); + concurrency::ThreadPool::TryBatchParallelFor( p_ctx->GetOperatorThreadPool(), static_cast(task_count), [&](ptrdiff_t task_idx) { auto offset = task_idx * hidden_size; const T* p_input = input_data + offset; - const T* p_skip = skip_data + offset; + const T* p_skip = skip_data + (offset % skip_size); T* p_output = output_data + offset; T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h new file mode 100644 index 0000000000..6271f82228 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/providers/common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace skip_layer_norm_helper { + +template +Status CheckInputs(const T* input, + const T* skip, + const T* gamma, + const T* beta, + const T* bias, + int hidden_size_check, + size_t input_dims_size_check) { + const auto& input_dims_check = input->Shape().GetDims(); + const auto& skip_dims_check = skip->Shape().GetDims(); + size_t skip_dims_size_check = skip_dims_check.size(); + + if (skip_dims_size_check != 3 && skip_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have 3 or 2 dimensions, got ", skip_dims_size_check); + } + + if ((input->Shape() != skip->Shape()) && ((skip_dims_check[0] != 1 || skip_dims_size_check != 2) && input_dims_size_check != 3)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions"); + } + + if (input_dims_size_check != 3 && input_dims_size_check != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check); + } + + if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "last two dimensions of skip needs to be same as input"); + } + + const auto& gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + if (nullptr != beta) { + const auto& beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", beta_dims.size()); + } + if (beta_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + } + + if (nullptr != bias) { + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != hidden_size_check) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of bias and input does not match"); + } + } + return Status::OK(); +} + +} // namespace skip_layer_norm_helper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 62f93b5e5b..0d8f887d98 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -5,6 +5,7 @@ #include "core/providers/cuda/nn/layer_norm_impl.h" #include "skip_layer_norm.h" #include "skip_layer_norm_impl.h" +#include "contrib_ops/cpu/skip_layer_norm_helper.h" namespace onnxruntime { namespace contrib { @@ -60,59 +61,23 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const // of the input and skip tensors Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape()); - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - - if (input->Shape().Size() == 0) { - return Status::OK(); - } - const auto& input_dims = input->Shape().GetDims(); size_t input_dims_size = input_dims.size(); - if (input_dims_size != 3 && input_dims_size != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 3 or 2 dimensions, got ", input_dims_size); - } + const auto& skip_dims = skip->Shape().GetDims(); + size_t skip_dims_size = skip_dims.size(); int hidden_size = static_cast(input_dims[input_dims_size - 1]); - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of gamma and input does not match"); - } + ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input, + skip, + gamma, + beta, + bias, + hidden_size, + input_dims_size)); - if (!Simplified) { - if (nullptr != beta) { - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of beta and input does not match"); - } - } - } - - if (nullptr != bias) { - const auto& bias_dims = bias->Shape().GetDims(); - if (bias_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "bias is expected to have 1 dimension, got ", bias_dims.size()); - } - if (bias_dims[0] != hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Last dimension of bias and input does not match"); - } - } + const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false; + const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]); if (strict_) { int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1)); @@ -131,7 +96,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add - skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr); + skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr, + skip_broadcasted, + skip_size); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); @@ -153,7 +120,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const epsilon_, hidden_size, static_cast(element_count), - element_size); + element_size, + skip_broadcasted, + skip_size); CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu index 0a815f7f49..1d0d97323d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -82,7 +82,7 @@ template __global__ void SkipLayerNormKernel( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, - const T epsilon, T* output, T* skip_input_bias_add_output) { + const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) { const T reverse_ld = T(1.f / ld); const int offset = blockIdx.x * ld; @@ -90,10 +90,13 @@ __global__ void SkipLayerNormKernel( // reduce x and x^2 cub::KeyValuePair thread_data(0, 0); + for (int i = threadIdx.x; i < ld; i += TPB) { const int idx = offset + i; - const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; + const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx]; + const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i]; + const T rldval = reverse_ld * val; thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); @@ -115,7 +118,7 @@ template __global__ void SkipLayerNormKernelSmall( const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output, - bool hasBias, bool hasSkipInputBiasAdditionOutput) { + bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) { const T rld = T(1.f / ld); const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld @@ -127,7 +130,11 @@ __global__ void SkipLayerNormKernelSmall( *input_val = *reinterpret_cast(&input[idx]); VecT* skip_val = reinterpret_cast(&skip_v); + if (skip_broadcasted){ + *skip_val = *reinterpret_cast(&skip[idx % skip_size]); + }else{ *skip_val = *reinterpret_cast(&skip[idx]); + } if (hasBias) { VecT* bias_val = reinterpret_cast(&bias_v); @@ -170,8 +177,13 @@ template Status LaunchSkipLayerNormKernel( cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma, const T* beta, const T* bias, float epsilon, const int ld, const int element_count, - size_t element_size) { + size_t element_size, const bool skip_broadcasted, const int skip_size) { // this must be true because n is the total size of the tensor + + if (element_count == 0) { + return Status::OK(); + } + assert(element_count % ld == 0); bool hasBias = (bias == nullptr) ? false : true; bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true; @@ -187,10 +199,10 @@ Status LaunchSkipLayerNormKernel( #define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \ SkipLayerNormKernelSmall \ <<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \ - skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput) + skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size) #define LAUNCH_SKIP_LAYER_NORM_KERNEL() \ SkipLayerNormKernel<<>>( \ - ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output) + ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size) #define CASE_NEXT_SIZE(next_size_value) \ case next_size_value: { \ if (flag_vec4) { \ @@ -219,7 +231,6 @@ Status LaunchSkipLayerNormKernel( #undef LAUNCH_SKIP_LAYER_NORM_KERNEL #undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL } - return CUDA_CALL(cudaGetLastError()); } @@ -229,7 +240,8 @@ Status LaunchSkipLayerNormKernel( const T* input, const T* skip, const T* gamma, \ const T* beta, const T* bias, float epsilon, \ const int ld, const int element_count, \ - size_t element_size); + size_t element_size, const bool skip_broadcasted, \ + const int skip_size); SKIPLAYERNORM_IMPL(float, true); SKIPLAYERNORM_IMPL(float, false); diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h index da2894928c..4900e64be1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -21,7 +21,9 @@ Status LaunchSkipLayerNormKernel( float epsilon, // Layer normalization epsilon int hidden_size, // hidden size, it is the leading dimension (ld) int element_count, // number of elements in input tensor - size_t element_size); + size_t element_size, + const bool skip_broadcasted, // determines if broadcasting should be implemented + const int skip_size); // determines size of the skip tensor } // namespace cuda } // namespace contrib diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index d2ff7e5351..e5956a575d 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1097,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .SetDoc("Skip and Layer Normalization Fusion") .Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultSkipLayerNormEpsilon) .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T") + .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)", "T") .Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T") .Input(3, "beta", "1D skip tensor with shape (hidden_size", "T", OpSchema::Optional) .Input(4, "bias", "1D bias tensor with shape (hidden_size", "T", OpSchema::Optional) diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index ba2f1ea52f..cf70a7d821 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -12,7 +12,8 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; +static constexpr std::array supported_data_types{ + "tensor(float16)", "tensor(float)", "tensor(bfloat16)"}; static bool IsSupportedDataType(const Node& node) { for (const auto& input_arg : node.InputDefs()) { diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 458dcf0b0f..4cc560a117 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -338,7 +338,9 @@ __global__ void cuApplyLayerNorm( const V* __restrict__ beta, const T* __restrict__ skip, const T* __restrict__ bias, - T* __restrict__ skip_input_bias_add_output) { + T* __restrict__ skip_input_bias_add_output, + const bool skip_broadcasted, + const int skip_size) { // Assumptions: // 1) blockDim.x == GPU_WARP_SIZE // 2) Tensors are contiguous @@ -358,11 +360,16 @@ __global__ void cuApplyLayerNorm( for (int i = thrx; i < n2; i += numx) { U curr = static_cast(lvals[i]); + + if (bias != NULL) { curr += static_cast(bias[i]); } - if (skip_vals != NULL) { + if (skip_vals != NULL && skip_broadcasted) { + int skip_i = i % skip_size; + curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor + }else if (skip_vals != NULL){ curr += static_cast(skip_vals[i]); } @@ -411,7 +418,9 @@ void HostApplyLayerNorm( const V* beta, const T* skip, const T* bias, - T* skip_input_bias_add_output) { + T* skip_input_bias_add_output, + const bool skip_broadcasted, + const int skip_size) { const int maxGridY = prop.maxGridSize[1]; const int warp_size = prop.warpSize; ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST); @@ -443,14 +452,17 @@ void HostApplyLayerNorm( n1, n2, U(epsilon), gamma, beta, - skip, bias, skip_input_bias_add_output); + skip, bias, skip_input_bias_add_output, + skip_broadcasted, + skip_size); } #define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output); + const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \ + const int skip_size); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index e3952eefae..d0d5db8ba3 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -43,7 +43,9 @@ void HostApplyLayerNorm( const V* beta, const T* skip = nullptr, const T* bias = nullptr, - T* skip_input_bias_add_output = nullptr); + T* skip_input_bias_add_output = nullptr, + const bool skip_broadcasted = false, + const int skip_size = 0); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 67e279dc08..a41a1dd4ec 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -27,18 +27,31 @@ static void RunTest( bool no_beta = false, bool simplified = false, bool use_token_count = false, - bool strict = false) { + bool strict = false, + bool broadcast_skip = false, + bool no_batch_size = false) { // Input and output shapes // Input 0 - input: (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) - // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) + // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size) // Input 2 - gamma: (hidden_size) // Input 3 - beta : (hidden_size) // Output : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector skip_dims = input_dims; + if (use_token_count) { input_dims = {batch_size * sequence_length, hidden_size}; + skip_dims = input_dims; } - std::vector skip_dims = input_dims; + + if (broadcast_skip) { + skip_dims = {1, sequence_length, hidden_size}; + } + + if (no_batch_size) { + skip_dims = {sequence_length, hidden_size}; + } + std::vector gamma_dims = {hidden_size}; std::vector beta_dims = gamma_dims; std::vector bias_dims = gamma_dims; @@ -48,6 +61,8 @@ static void RunTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); + auto cpu_ep = DefaultCpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, input_data); @@ -77,7 +92,10 @@ static void RunTest( skip_input_bias_add_output_data); } - test.Run(); + if (cpu_ep != nullptr) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || rocm_ep != nullptr) { @@ -109,7 +127,6 @@ static void RunTest( ToFloat16(skip_input_bias_add_output_data)); } - std::vector> execution_providers; if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { @@ -718,6 +735,100 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) { true, true); } + +TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector skip_data = { + 0.1f, -0.2f, 0.3f, 1.0f, + 0.5f, 0.1f, 0.4f, 1.6f}; + + std::vector gamma_data = { + 0.3f, 0.2f, 4.0f, 2.2f}; + + std::vector beta_data = { + 0.2f, 0.1f, 0.4f, 1.6f}; + + std::vector output_data = { + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438, + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438}; + + RunTest(input_data, + skip_data, + gamma_data, + beta_data, + std::vector(), + output_data, + {}, + epsilon_, + batch_size, + sequence_length, + hidden_size, + false, + false, + false, + false, + false, + false, + true); +} + +TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector skip_data = { + 0.1f, -0.2f, 0.3f, 1.0f, + 0.5f, 0.1f, 0.4f, 1.6f}; + + std::vector gamma_data = { + 0.3f, 0.2f, 4.0f, 2.2f}; + + std::vector beta_data = { + 0.2f, 0.1f, 0.4f, 1.6f}; + + std::vector output_data = { + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438, + 0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578, + 0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438}; + + RunTest(input_data, + skip_data, + gamma_data, + beta_data, + std::vector(), + output_data, + {}, + epsilon_, + batch_size, + sequence_length, + hidden_size, + false, + false, + false, + false, + false, + true, + false); +} #endif } // namespace test