diff --git a/.vscode/settings.json b/.vscode/settings.json
index fd28e2d7b3..b7a1292efb 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -40,4 +40,3 @@
"-build/include_subdir",
"-runtime/references"
]
-}
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 77b84e2b3f..4cdadeb1ff 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -4651,7 +4651,7 @@ This version of the operator has been available since version 1 of the 'com.micr
input : T
3D input tensor with shape (batch_size, sequence_length, hidden_size)
skip : T
-3D skip tensor with shape (batch_size, sequence_length, hidden_size)
+3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
gamma : T
1D input tensor with shape (hidden_size)
beta (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index b3668bcc0b..e86a12d9fb 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -6,6 +6,7 @@
#include "core/providers/common.h"
#include "core/platform/threadpool.h"
#include "skip_layer_norm.h"
+#include "skip_layer_norm_helper.h"
namespace onnxruntime {
namespace contrib {
@@ -45,51 +46,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
const auto& input_dims = input->Shape().GetDims();
size_t input_dims_size = input_dims.size();
- if (input_dims_size != 3 && input_dims_size != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "input is expected to have 3 or 2 dimensions, got ", input_dims_size);
- }
-
int hidden_size = static_cast(input_dims[input_dims_size - 1]);
- if (input->Shape() != skip->Shape()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "skip is expected to have same shape as input");
- }
-
- const auto& gamma_dims = gamma->Shape().GetDims();
- if (gamma_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "gamma is expected to have 1 dimension, got ", gamma_dims.size());
- }
- if (gamma_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of gamma and input does not match");
- }
-
- if (nullptr != beta) {
- const auto& beta_dims = beta->Shape().GetDims();
- if (beta_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "beta is expected to have 1 dimension, got ", beta_dims.size());
- }
- if (beta_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of beta and input does not match");
- }
- }
-
- if (nullptr != bias) {
- const auto& bias_dims = bias->Shape().GetDims();
- if (bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "bias is expected to have 1 dimension, got ", bias_dims.size());
- }
- if (bias_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of bias and input does not match");
- }
- }
+ ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input,
+ skip,
+ gamma,
+ beta,
+ bias,
+ hidden_size,
+ input_dims_size));
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
@@ -105,13 +70,15 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
// of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData() : nullptr;
+ const auto& skip_size = skip->Shape().Size();
+
concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast(task_count),
[&](ptrdiff_t task_idx) {
auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
- const T* p_skip = skip_data + offset;
+ const T* p_skip = skip_data + (offset % skip_size);
T* p_output = output_data + offset;
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h
new file mode 100644
index 0000000000..6271f82228
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/common.h"
+#include "core/providers/common.h"
+#include "contrib_ops/cpu/bert/attention_common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace skip_layer_norm_helper {
+
+// Validates SkipLayerNorm input shapes on behalf of the CPU and CUDA kernels.
+//   input : 2D or 3D tensor.
+//   skip  : same shape as input or, for a 3D input only, (1, S, H) or (S, H) reused for every batch.
+//   gamma : 1D of length hidden_size_check; beta and bias may be null, otherwise 1D of that length.
+//   hidden_size_check / input_dims_size_check : supplied by the caller as input's last dim and rank.
+// Returns Status::OK(), or an INVALID_ARGUMENT status describing the first violation found.
+template <typename T>
+Status CheckInputs(const T* input,
+                   const T* skip,
+                   const T* gamma,
+                   const T* beta,
+                   const T* bias,
+                   int hidden_size_check,
+                   size_t input_dims_size_check) {
+  const auto& input_dims_check = input->Shape().GetDims();
+  const auto& skip_dims_check = skip->Shape().GetDims();
+  const size_t skip_dims_size_check = skip_dims_check.size();
+
+  // Both ranks are validated first so the dimension indexing below can never go out of range.
+  if (skip_dims_size_check != 3 && skip_dims_size_check != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "skip is expected to have 3 or 2 dimensions, got ", skip_dims_size_check);
+  }
+  if (input_dims_size_check != 3 && input_dims_size_check != 2) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "input is expected to have 3 or 2 dimensions, got ", input_dims_size_check);
+  }
+
+  if (input->Shape() != skip->Shape()) {
+    // A differing skip is only valid as a batch broadcast: input must be 3D and skip must be 2D or
+    // have a leading dimension of 1. Rejecting everything else keeps a skip with an unrelated batch
+    // size (e.g. input (5,S,H), skip (2,S,H)) away from kernels that would index past its end.
+    const bool skip_broadcastable = skip_dims_size_check == 2 || skip_dims_check[0] == 1;
+    if (input_dims_size_check != 3 || !skip_broadcastable) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions");
+    }
+  }
+
+  if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] ||
+      skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "last two dimensions of skip needs to be same as input");
+  }
+
+  // gamma, beta and bias share one rule: 1D with hidden_size_check elements.
+  const auto check_1d = [hidden_size_check](const T* tensor, const char* name) -> Status {
+    const auto& dims = tensor->Shape().GetDims();
+    if (dims.size() != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             name, " is expected to have 1 dimension, got ", dims.size());
+    }
+    if (dims[0] != hidden_size_check) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Last dimension of ", name, " and input does not match");
+    }
+    return Status::OK();
+  };
+  ORT_RETURN_IF_ERROR(check_1d(gamma, "gamma"));
+  if (nullptr != beta) {
+    ORT_RETURN_IF_ERROR(check_1d(beta, "beta"));
+  }
+  if (nullptr != bias) {
+    ORT_RETURN_IF_ERROR(check_1d(bias, "bias"));
+  }
+  return Status::OK();
+}
+
+} // namespace skip_layer_norm_helper
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
index 62f93b5e5b..0d8f887d98 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -5,6 +5,7 @@
#include "core/providers/cuda/nn/layer_norm_impl.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_impl.h"
+#include "contrib_ops/cpu/skip_layer_norm_helper.h"
namespace onnxruntime {
namespace contrib {
@@ -60,59 +61,23 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const
// of the input and skip tensors
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
- if (input->Shape() != skip->Shape()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "skip is expected to have same shape as input");
- }
-
- if (input->Shape().Size() == 0) {
- return Status::OK();
- }
-
const auto& input_dims = input->Shape().GetDims();
size_t input_dims_size = input_dims.size();
- if (input_dims_size != 3 && input_dims_size != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "input is expected to have 3 or 2 dimensions, got ", input_dims_size);
- }
+ const auto& skip_dims = skip->Shape().GetDims();
+ size_t skip_dims_size = skip_dims.size();
int hidden_size = static_cast(input_dims[input_dims_size - 1]);
- const auto& gamma_dims = gamma->Shape().GetDims();
- if (gamma_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "gamma is expected to have 1 dimension, got ", gamma_dims.size());
- }
- if (gamma_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of gamma and input does not match");
- }
+ ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs(input,
+ skip,
+ gamma,
+ beta,
+ bias,
+ hidden_size,
+ input_dims_size));
- if (!Simplified) {
- if (nullptr != beta) {
- const auto& beta_dims = beta->Shape().GetDims();
- if (beta_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "beta is expected to have 1 dimension, got ", beta_dims.size());
- }
- if (beta_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of beta and input does not match");
- }
- }
- }
-
- if (nullptr != bias) {
- const auto& bias_dims = bias->Shape().GetDims();
- if (bias_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "bias is expected to have 1 dimension, got ", bias_dims.size());
- }
- if (bias_dims[0] != hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Last dimension of bias and input does not match");
- }
- }
+ const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
+ const int skip_size = static_cast(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
if (strict_) {
int row_count = gsl::narrow(input->Shape().SizeToDimension(input_dims_size - 1));
@@ -131,7 +96,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const
(beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta
reinterpret_cast(skip->Data()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add
- skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr);
+ skip_input_bias_add_output != nullptr ? reinterpret_cast(skip_input_bias_add_output->MutableData()) : nullptr,
+ skip_broadcasted,
+ skip_size);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
@@ -153,7 +120,9 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const
epsilon_,
hidden_size,
static_cast(element_count),
- element_size);
+ element_size,
+ skip_broadcasted,
+ skip_size);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index 0a815f7f49..1d0d97323d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -82,7 +82,7 @@ template
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip,
const T* beta, const T* gamma, const T* bias,
- const T epsilon, T* output, T* skip_input_bias_add_output) {
+ const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;
@@ -90,10 +90,13 @@ __global__ void SkipLayerNormKernel(
// reduce x and x^2
cub::KeyValuePair thread_data(0, 0);
+
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
- const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
+ const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx];
+ const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i];
+
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val));
@@ -115,7 +118,7 @@ template
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
- bool hasBias, bool hasSkipInputBiasAdditionOutput) {
+ bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
@@ -127,7 +130,11 @@ __global__ void SkipLayerNormKernelSmall(
*input_val = *reinterpret_cast(&input[idx]);
VecT* skip_val = reinterpret_cast(&skip_v);
+ if (skip_broadcasted){
+ *skip_val = *reinterpret_cast(&skip[idx % skip_size]);
+ }else{
*skip_val = *reinterpret_cast(&skip[idx]);
+ }
if (hasBias) {
VecT* bias_val = reinterpret_cast(&bias_v);
@@ -170,8 +177,13 @@ template
Status LaunchSkipLayerNormKernel(
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
- size_t element_size) {
+ size_t element_size, const bool skip_broadcasted, const int skip_size) {
// this must be true because n is the total size of the tensor
+
+ if (element_count == 0) {
+ return Status::OK();
+ }
+
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;
@@ -187,10 +199,10 @@ Status LaunchSkipLayerNormKernel(
#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
SkipLayerNormKernelSmall \
<<>>(ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, \
- skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput)
+ skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size)
#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
SkipLayerNormKernel<<>>( \
- ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output)
+ ld, input, skip, beta, gamma, bias, maybe2half(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size)
#define CASE_NEXT_SIZE(next_size_value) \
case next_size_value: { \
if (flag_vec4) { \
@@ -219,7 +231,6 @@ Status LaunchSkipLayerNormKernel(
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
}
-
return CUDA_CALL(cudaGetLastError());
}
@@ -229,7 +240,8 @@ Status LaunchSkipLayerNormKernel(
const T* input, const T* skip, const T* gamma, \
const T* beta, const T* bias, float epsilon, \
const int ld, const int element_count, \
- size_t element_size);
+ size_t element_size, const bool skip_broadcasted, \
+ const int skip_size);
SKIPLAYERNORM_IMPL(float, true);
SKIPLAYERNORM_IMPL(float, false);
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
index da2894928c..4900e64be1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -21,7 +21,9 @@ Status LaunchSkipLayerNormKernel(
float epsilon, // Layer normalization epsilon
int hidden_size, // hidden size, it is the leading dimension (ld)
int element_count, // number of elements in input tensor
- size_t element_size);
+ size_t element_size,
+ const bool skip_broadcasted, // determines if broadcasting should be implemented
+ const int skip_size); // determines size of the skip tensor
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index d2ff7e5351..e5956a575d 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1097,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.SetDoc("Skip and Layer Normalization Fusion")
.Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultSkipLayerNormEpsilon)
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
- .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+ .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)", "T")
.Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T")
.Input(3, "beta", "1D skip tensor with shape (hidden_size", "T", OpSchema::Optional)
.Input(4, "bias", "1D bias tensor with shape (hidden_size", "T", OpSchema::Optional)
diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
index ba2f1ea52f..cf70a7d821 100644
--- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
+++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc
@@ -12,7 +12,8 @@ using namespace onnxruntime::common;
namespace onnxruntime {
// LayerNorm supports limited data types.
-static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
+static constexpr std::array supported_data_types{
+ "tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
static bool IsSupportedDataType(const Node& node) {
for (const auto& input_arg : node.InputDefs()) {
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
index 458dcf0b0f..4cc560a117 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -338,7 +338,9 @@ __global__ void cuApplyLayerNorm(
const V* __restrict__ beta,
const T* __restrict__ skip,
const T* __restrict__ bias,
- T* __restrict__ skip_input_bias_add_output) {
+ T* __restrict__ skip_input_bias_add_output,
+ const bool skip_broadcasted,
+ const int skip_size) {
// Assumptions:
// 1) blockDim.x == GPU_WARP_SIZE
// 2) Tensors are contiguous
@@ -358,11 +360,16 @@ __global__ void cuApplyLayerNorm(
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast(lvals[i]);
+
+
if (bias != NULL) {
curr += static_cast(bias[i]);
}
- if (skip_vals != NULL) {
+ if (skip_vals != NULL && skip_broadcasted) {
+ int skip_i = i % skip_size;
+ curr += static_cast(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor
+ }else if (skip_vals != NULL){
curr += static_cast(skip_vals[i]);
}
@@ -411,7 +418,9 @@ void HostApplyLayerNorm(
const V* beta,
const T* skip,
const T* bias,
- T* skip_input_bias_add_output) {
+ T* skip_input_bias_add_output,
+ const bool skip_broadcasted,
+ const int skip_size) {
const int maxGridY = prop.maxGridSize[1];
const int warp_size = prop.warpSize;
ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST);
@@ -443,14 +452,17 @@ void HostApplyLayerNorm(
n1, n2,
U(epsilon),
gamma, beta,
- skip, bias, skip_input_bias_add_output);
+ skip, bias, skip_input_bias_add_output,
+ skip_broadcasted,
+ skip_size);
}
#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
double epsilon, const V* gamma, const V* beta, const T* skip, \
- const T* bias, T* skip_input_bias_add_output);
+ const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \
+ const int skip_size);
LAYERNORM_LINEAR_IMPL(float, float, float, true)
LAYERNORM_LINEAR_IMPL(half, float, half, true)
diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
index e3952eefae..d0d5db8ba3 100644
--- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
+++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h
@@ -43,7 +43,9 @@ void HostApplyLayerNorm(
const V* beta,
const T* skip = nullptr,
const T* bias = nullptr,
- T* skip_input_bias_add_output = nullptr);
+ T* skip_input_bias_add_output = nullptr,
+ const bool skip_broadcasted = false,
+ const int skip_size = 0);
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
index 67e279dc08..a41a1dd4ec 100644
--- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -27,18 +27,31 @@ static void RunTest(
bool no_beta = false,
bool simplified = false,
bool use_token_count = false,
- bool strict = false) {
+ bool strict = false,
+ bool broadcast_skip = false,
+ bool no_batch_size = false) {
// Input and output shapes
// Input 0 - input: (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
- // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
+ // Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
// Input 2 - gamma: (hidden_size)
// Input 3 - beta : (hidden_size)
// Output : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
std::vector input_dims = {batch_size, sequence_length, hidden_size};
+ std::vector skip_dims = input_dims;
+
if (use_token_count) {
input_dims = {batch_size * sequence_length, hidden_size};
+ skip_dims = input_dims;
}
- std::vector skip_dims = input_dims;
+
+ if (broadcast_skip) {
+ skip_dims = {1, sequence_length, hidden_size};
+ }
+
+ if (no_batch_size) {
+ skip_dims = {sequence_length, hidden_size};
+ }
+
std::vector gamma_dims = {hidden_size};
std::vector beta_dims = gamma_dims;
std::vector bias_dims = gamma_dims;
@@ -48,6 +61,8 @@ static void RunTest(
auto rocm_ep = DefaultRocmExecutionProvider();
auto dml_ep = DefaultDmlExecutionProvider();
+ auto cpu_ep = DefaultCpuExecutionProvider();
+ std::vector> execution_providers;
if (!use_float16) {
OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
test.AddInput("input", input_dims, input_data);
@@ -77,7 +92,10 @@ static void RunTest(
skip_input_bias_add_output_data);
}
- test.Run();
+ if (cpu_ep != nullptr) {
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ }
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
dml_ep != nullptr ||
rocm_ep != nullptr) {
@@ -109,7 +127,6 @@ static void RunTest(
ToFloat16(skip_input_bias_add_output_data));
}
- std::vector> execution_providers;
if (dml_ep != nullptr) {
execution_providers.push_back(DefaultDmlExecutionProvider());
} else if (rocm_ep != nullptr) {
@@ -718,6 +735,100 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
true,
true);
}
+
+// Checks that a 2D skip of shape (sequence_length, hidden_size) is applied to every batch entry of
+// a 3D input. Both batches carry identical input rows, so the expected output rows repeat as well.
+TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
+  int batch_size = 2;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  // (batch_size, sequence_length, hidden_size) = (2, 2, 4); batch 1 duplicates batch 0.
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f,
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f};
+
+  // Only sequence_length * hidden_size values: one skip slice shared by both batches.
+  std::vector<float> skip_data = {
+      0.1f, -0.2f, 0.3f, 1.0f,
+      0.5f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> gamma_data = {
+      0.3f, 0.2f, 4.0f, 2.2f};
+
+  std::vector<float> beta_data = {
+      0.2f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> output_data = {
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
+
+  RunTest(input_data,
+          skip_data,
+          gamma_data,
+          beta_data,
+          std::vector<float>(),  // presumably bias_data (empty) - confirm against RunTest
+          output_data,
+          {},
+          epsilon_,
+          batch_size,
+          sequence_length,
+          hidden_size,
+          false,  // presumably use_float16 (the parameter before no_beta) - confirm
+          false,  // no_beta
+          false,  // simplified
+          false,  // use_token_count
+          false,  // strict
+          false,  // broadcast_skip
+          true);  // no_batch_size: skip is (sequence_length, hidden_size)
+}
+
+// Checks that a 3D skip with a leading dimension of 1, i.e. (1, sequence_length, hidden_size), is
+// applied to every batch entry of a 3D input. Both batches carry identical input rows, so the
+// expected output rows repeat as well.
+TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) {
+  int batch_size = 2;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  // (batch_size, sequence_length, hidden_size) = (2, 2, 4); batch 1 duplicates batch 0.
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f,
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f};
+
+  // Only sequence_length * hidden_size values: one skip slice shared by both batches.
+  std::vector<float> skip_data = {
+      0.1f, -0.2f, 0.3f, 1.0f,
+      0.5f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> gamma_data = {
+      0.3f, 0.2f, 4.0f, 2.2f};
+
+  std::vector<float> beta_data = {
+      0.2f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> output_data = {
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
+
+  RunTest(input_data,
+          skip_data,
+          gamma_data,
+          beta_data,
+          std::vector<float>(),  // presumably bias_data (empty) - confirm against RunTest
+          output_data,
+          {},
+          epsilon_,
+          batch_size,
+          sequence_length,
+          hidden_size,
+          false,   // presumably use_float16 (the parameter before no_beta) - confirm
+          false,   // no_beta
+          false,   // simplified
+          false,   // use_token_count
+          false,   // strict
+          true,    // broadcast_skip: skip is (1, sequence_length, hidden_size)
+          false);  // no_batch_size
+}
#endif
} // namespace test