From 73f5b0c5976fec04ebbdfe0761ee23055e0fcf8b Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Fri, 10 Jan 2025 21:57:18 -0800
Subject: [PATCH] LayerNormalization broadcast (limited support for axis=2) (#23297)

### Description

The LayerNormalization spec supports broadcasting (tensors Scale and B should be unidirectionally broadcastable to tensor X).
https://onnx.ai/onnx/operators/onnx__LayerNormalization.html

However, the current implementation only allows scale and bias whose size equals the size of X.shape()[axis:].

Examples of input tensors that are normalized with axis=2:

| X shape | Scale shape | B shape | Before | After |
| - | - | - | - | - |
| (B, S, D) | (D) | (D) | Supported | Supported |
| (B, S, D) | (1, 1, D) | (1, 1, D) | Supported | Supported |
| (B, S, D) | (B, 1, D) | (B, 1, D) | Not Supported | Supported |
| (B, S, D) | (1, S, D) | (1, S, D) | Not Supported | Supported |
| (B, S, D) | (B, S, D) | (B, S, D) | Not Supported | Supported |

Here we add limited support: axis=2; scale and bias have the same shape; scale, bias and X have the same number of dimensions. This covers common use cases in LLM and vision models.

### Motivation and Context

Support the Stable Diffusion 3.x and Flux models.
--- .../contrib_ops/cuda/bert/skip_layer_norm.cc | 1 + .../core/providers/cpu/nn/layer_norm_helper.h | 116 ++++++++++++++ .../core/providers/cpu/nn/layer_norm_impl.cc | 62 ++++---- .../core/providers/cpu/nn/layer_norm_impl.h | 8 +- .../core/providers/cuda/nn/layer_norm.cc | 28 ++-- .../core/providers/cuda/nn/layer_norm_impl.cu | 27 ++-- .../core/providers/cuda/nn/layer_norm_impl.h | 1 + .../test/contrib_ops/layer_norm_op_test.cc | 148 +++++++++++++++++- .../microbenchmark/layer_normalization.cc | 4 +- 9 files changed, 331 insertions(+), 64 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/nn/layer_norm_helper.h diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 3299bc2cb1..428b903c03 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -101,6 +101,7 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const (double)epsilon_, // epsilon reinterpret_cast(gamma->Data()), // gamma (beta != nullptr) ? reinterpret_cast(beta->Data()) : nullptr, // beta + 0, // no broadcast for gamma/beta reinterpret_cast(skip->Data()), // skip or residual to add (bias != nullptr) ? reinterpret_cast(bias->Data()) : nullptr, // bias to add sum_output != nullptr ? reinterpret_cast(sum_output->MutableData()) : nullptr); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h b/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h new file mode 100644 index 0000000000..ed5ea83d9d --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_helper.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/tensor_shape.h" +#include "core/common/status.h" +#include "core/common/narrow.h" + +namespace onnxruntime { + +constexpr const char* kLayerNormInputShapeMismatchError = + "Size of scale and bias (if provided) must match X.shape[axis:], " + "or scale and bias (with same shape) can be broadcasted to X when axis is 2."; + +constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got "; + +constexpr int64_t kLayerNormInvalidInput = -1; + +struct LayerNormParams { + int64_t num_rows; + int64_t norm_size; // size per row + int64_t scale_size; + int64_t bias_size; + int64_t broadcast_param; +}; + +// We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns. +// When X shape is (B, S, ...), and x_row (index of one row in X) is in the range of [0, B * S). +// We support scale and bias shape like below: +// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0. +// When scale and bias shape is (B, 1, ...), value of broadcast_param is S. +// When scale and bias shape is (B, S, ...), value of broadcast_param is 1. +// When scale and bias shape is (1, S, ...), value of broadcast_param is -S. + +// Below is a macro to compute the offset for scale and bias data for a row of X. +#ifndef LAYER_NORM_SCALE_BIAS_OFFSET +#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \ + ((broadcast_param == 0) ? 0 \ + : norm_size * (broadcast_param > 0 ? 
x_row / broadcast_param : x_row % (-broadcast_param))) +#endif + +class LayerNormHelper { + public: + static Status CheckInputs(const TensorShape& x_shape, + const TensorShape& scale_shape, + const TensorShape& bias_shape, + bool has_bias, + int64_t axis, + LayerNormParams& params) { + params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow(axis)); + params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); + params.scale_size = scale_shape.Size(); + params.bias_size = bias_shape.Size(); + params.broadcast_param = 0; + + if (params.norm_size <= 1) { + params.broadcast_param = kLayerNormInvalidInput; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size); + } else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) { + params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis); + if (params.broadcast_param == kLayerNormInvalidInput) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + kLayerNormInputShapeMismatchError, + " X.shape=", x_shape, + " scale.shape=", scale_shape, + " bias.shape=", bias_shape, + " and axis=", axis); + } + } + return Status::OK(); + } + + private: + static int64_t GetBroadcastParam(const TensorShape& x_shape, + const TensorShape& scale_shape, + const TensorShape* bias_shape, + int64_t axis) { + // Note that when size of scale and bias is norm_size, it won't enter this function (see CheckInputs). + + // X shape is (B, S, ...) + if (axis == 2 && + x_shape.NumDimensions() >= 3 && + x_shape.NumDimensions() == scale_shape.NumDimensions() && + (bias_shape == nullptr || *bias_shape == scale_shape)) { + for (size_t i = 2; i < x_shape.NumDimensions(); ++i) { + if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) { + // scale cannot be broadcasted to X. It is invalid input. 
+ return kLayerNormInvalidInput; + } + } + + if (x_shape.GetDims()[0] == scale_shape.GetDims()[0]) { + // scale and bias shape is (B, S, ...). + if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) { + return 1; + } + + // scale and bias shape is (B, 1, ...), returns S + if (scale_shape.GetDims()[1] == 1) { + return x_shape.GetDims()[1]; + } + } else if (scale_shape.GetDims()[0] == 1) { + // scale and bias shape is (1, S, ...), returns -S + if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) { + return -(x_shape.GetDims()[1]); + } + } + } + + // Other cases that are not supported. + return kLayerNormInvalidInput; + } +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc index 24a5dcab22..9a6295def4 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "layer_norm_impl.h" +#include "layer_norm_helper.h" #include "core/common/safeint.h" #include "core/framework/tensor.h" @@ -24,6 +25,7 @@ void ComputeJob( const T* bias_data, const ptrdiff_t task_idx, const int64_t norm_size, + const int64_t broadcast_param, const float* scale_float_ptr, const float* bias_float_ptr, float epsilon, @@ -55,13 +57,16 @@ void ComputeJob( mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); } - for (int64_t h = 0; h < norm_size; h++) { + // Compute the offset of gamma and beta to support broadcasting. 
+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); + + for (int64_t h = 0; h < norm_size; h++, i++) { if (simplified) { - p_output[h] = p_output[h] / mean_square * scale_data[h]; + p_output[h] = p_output[h] / mean_square * scale_data[i]; } else if (nullptr == bias_data) { - p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h]; + p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i]; } else { - p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h]; + p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i]; } } @@ -82,6 +87,7 @@ void ComputeJob( const MLFloat16* bias_data, const ptrdiff_t task_idx, const int64_t norm_size, + const int64_t broadcast_param, const float* scale_float_ptr, const float* bias_float_ptr, float epsilon, @@ -120,13 +126,16 @@ void ComputeJob( mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); } - for (size_t h = 0; h < num_elems; h++) { + // Compute the offset of gamma and beta to support broadcasting. 
+ int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size); + + for (size_t h = 0; h < num_elems; h++, i++) { if (simplified) { - output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h]; + output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i]; } else if (nullptr == bias_float_ptr) { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h]; + output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i]; } else { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h]; + output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i]; } } @@ -161,9 +170,7 @@ LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified simplified_{simplified}, contrib_op_{contrib_op}, prepacked_scale_fp32_data_(nullptr), - prepacked_scale_fp32_size_(0), - prepacked_bias_fp32_data_(nullptr), - prepacked_bias_fp32_size_(0) { + prepacked_bias_fp32_data_(nullptr) { ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK()); ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); } @@ -179,8 +186,8 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data(); const TensorShape& x_shape = X->Shape(); - size_t scale_size = scale ? static_cast(scale->Shape().Size()) : prepacked_scale_fp32_size_; - size_t bias_size = bias ? static_cast(bias->Shape().Size()) : prepacked_bias_fp32_size_; + const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_; + const TensorShape& bias_shape = bias ? 
bias->Shape() : prepacked_bias_fp32_shape_; Tensor* Y = p_ctx->Output(0, x_shape); T* Y_data = Y->MutableData(); @@ -215,7 +222,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); - return ComputeWithoutContext(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data, + return ComputeWithoutContext(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data, inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc); } @@ -234,10 +241,10 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr is_packed = false; if (input_idx == 1) { // scale - prepacked_scale_fp32_size_ = static_cast(tensor.Shape().Size()); + prepacked_scale_fp32_shape_ = tensor.Shape(); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed); } else if (input_idx == 2) { // bias - prepacked_bias_fp32_size_ = static_cast(tensor.Shape().Size()); + prepacked_bias_fp32_shape_ = tensor.Shape(); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } @@ -249,9 +256,9 @@ Status LayerNormImpl::ComputeWithoutContext( const T* X_data, const TensorShape& x_shape, const T* scale_data, - size_t scale_size, + const TensorShape& scale_shape, const T* bias_data, - size_t bias_size, + const TensorShape& bias_shape, T* Y_data, U* mean_data, U* inv_std_dev_data, @@ -260,35 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext( float epsilon, bool simplified, AllocatorPtr alloc) const { - int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow(axis)); - int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow(axis)); - - if (static_cast(scale_size) != norm_size || (bias_data && static_cast(bias_size) != norm_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size of X.shape()[axis:] == ", norm_size, - ". 
Size of scale and bias (if provided) must match this. Got scale size of ", - scale_size, " and bias size of ", bias_size); - } + LayerNormParams params; + ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params)); IAllocatorUniquePtr scale_fp32; IAllocatorUniquePtr bias_fp32; if constexpr (std::is_same_v) { if (prepacked_scale_fp32_data_ == nullptr) { - const size_t num_elems = static_cast(norm_size); + const size_t num_elems = static_cast(params.scale_size); scale_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems); } if (prepacked_bias_fp32_data_ == nullptr && bias_data) { - const size_t num_elems = static_cast(norm_size); + const size_t num_elems = static_cast(params.bias_size); bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); } } concurrency::ThreadPool::TryBatchParallelFor( - thread_pool, static_cast(norm_count), + thread_pool, static_cast(params.num_rows), [&](ptrdiff_t task_idx) { - ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, + ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param, prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(), prepacked_bias_fp32_data_ ? 
prepacked_bias_fp32_data_.get() : bias_fp32.get(), epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc); diff --git a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h index f8b528b398..a2debb1679 100644 --- a/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cpu/nn/layer_norm_impl.h @@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel { const T* X_data, const TensorShape& x_shape, const T* scale_data, - size_t scale_size, + const TensorShape& scale_shape, const T* bias_data, - size_t bias_size, + const TensorShape& bias_shape, T* Y_data, U* mean_data, U* inv_std_dev, @@ -64,9 +64,9 @@ class LayerNormImpl : public OpKernel { const bool simplified_; const bool contrib_op_; IAllocatorUniquePtr prepacked_scale_fp32_data_; - size_t prepacked_scale_fp32_size_; + TensorShape prepacked_scale_fp32_shape_; IAllocatorUniquePtr prepacked_bias_fp32_data_; - size_t prepacked_bias_fp32_size_; + TensorShape prepacked_bias_fp32_shape_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.cc b/onnxruntime/core/providers/cuda/nn/layer_norm.cc index 7dd10f9c29..d479261855 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.cc @@ -4,6 +4,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/nn/layer_norm.h" #include "core/providers/cuda/nn/layer_norm_impl.h" +#include "core/providers/cpu/nn/layer_norm_helper.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { @@ -44,20 +45,14 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con auto bias_data = (simplified || (nullptr == bias)) ? 
nullptr : reinterpret_cast(bias->Data()); const TensorShape& x_shape = X->Shape(); - const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); + auto x_num_dims = x_shape.NumDimensions(); + const int64_t axis = HandleNegativeAxis(axis_, x_num_dims); - int n1 = gsl::narrow(x_shape.SizeToDimension(axis)); - int n2 = gsl::narrow(x_shape.SizeFromDimension(axis)); + const TensorShape& scale_shape = scale->Shape(); + const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape(); - const auto scale_size = scale->Shape().Size(); - const auto bias_size = (bias_data) ? bias->Shape().Size() : 0; - if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Size of X.shape()[axis:] == ", n2, - ". Size of scale and bias (if provided) must match this " - "and the size must not be 1. Got scale size of ", - scale_size, " and bias size of ", bias_size); - } + LayerNormParams params; + ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params)); // Outputs Tensor* Y = ctx->Output(0, x_shape); @@ -65,7 +60,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con // Mean and variance std::vector mean_inv_std_var_dim; - for (int i = 0; i < static_cast(x_shape.NumDimensions()); ++i) { + for (int i = 0; i < static_cast(x_num_dims); ++i) { if (i < axis) { mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]); } else { @@ -93,8 +88,11 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) con return Status::OK(); } - HostApplyLayerNorm(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, - X_data, n1, n2, epsilon_, scale_data, bias_data); + HostApplyLayerNorm( + GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data, + onnxruntime::narrow(params.num_rows), onnxruntime::narrow(params.norm_size), epsilon_, + scale_data, bias_data, + onnxruntime::narrow(params.broadcast_param)); 
CUDA_RETURN_IF_ERROR(cudaGetLastError()); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index b9e8b45307..90b542beaa 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -23,8 +23,8 @@ /* Modifications Copyright (c) Microsoft. */ #include "core/providers/cuda/cu_inc/common.cuh" - #include "layer_norm_impl.h" +#include "core/providers/cpu/nn/layer_norm_helper.h" namespace onnxruntime { namespace cuda { @@ -334,6 +334,7 @@ __global__ void cuApplyLayerNorm( const U epsilon, const V* __restrict__ gamma, const V* __restrict__ beta, + int broadcast_param, const T* __restrict__ skip, const T* __restrict__ bias, T* __restrict__ skip_input_bias_add_output) { @@ -353,6 +354,10 @@ __global__ void cuApplyLayerNorm( V* ovals = output_vals + offset; T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr; U c_inv_std_dev = rsqrt(sigma2 + epsilon); + + // Compute the offset of gamma and beta to support broadcasting. + int gamma_beta_offset = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, i1, n2); + const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; for (int i = thrx; i < n2; i += numx) { @@ -366,8 +371,10 @@ __global__ void cuApplyLayerNorm( curr += static_cast(skip_vals[i]); } - U gamma_i = (gamma != nullptr) ? (U)gamma[i] : (U)1; - U beta_i = (beta != nullptr) ? (U)beta[i] : (U)0; + int index = gamma_beta_offset + i; + U gamma_i = (gamma != nullptr) ? (U)gamma[index] : (U)1; + U beta_i = (beta != nullptr) ? 
(U)beta[index] : (U)0; + if (simplified) { ovals[i] = static_cast(gamma_i * c_inv_std_dev * curr); } else { @@ -409,6 +416,7 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, + int broadcast_param, const T* skip, const T* bias, T* skip_input_bias_add_output) { @@ -442,15 +450,16 @@ void HostApplyLayerNorm( input, n1, n2, U(epsilon), - gamma, beta, + gamma, beta, broadcast_param, skip, bias, skip_input_bias_add_output); } -#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ - template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ - U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ - double epsilon, const V* gamma, const V* beta, const T* skip, \ - const T* bias, T* skip_input_bias_add_output); +#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \ + template void HostApplyLayerNorm(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \ + U* mean, U* inv_std_dev, const T* input, int n1, int n2, \ + double epsilon, const V* gamma, const V* beta, \ + int broadcast_param, \ + const T* skip, const T* bias, T* skip_input_bias_add_output); LAYERNORM_LINEAR_IMPL(float, float, float, true) LAYERNORM_LINEAR_IMPL(half, float, half, true) diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h index e3952eefae..4e74aa9ab6 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.h @@ -41,6 +41,7 @@ void HostApplyLayerNorm( double epsilon, const V* gamma, const V* beta, + int broadcast_param = 0, // parameter for broadcasting gamma/beta. 
const T* skip = nullptr, const T* bias = nullptr, T* skip_input_bias_add_output = nullptr); diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 52e67bf061..4611dc9082 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -4,6 +4,7 @@ #include #include #include "core/framework/tensor.h" +#include "core/providers/cpu/nn/layer_norm_helper.h" #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" @@ -20,6 +21,33 @@ using namespace std; namespace onnxruntime { namespace test { +// Some feature (like broadcast support) are implemented in CPU and CUDA/ROCM provider only. A helper to run tests. +void RunTestOnCpuAndCuda(OpTester& test, const std::string& expected_failure_msg = "") { + auto expected_result = expected_failure_msg.empty() + ? OpTester::ExpectResult::kExpectSuccess + : OpTester::ExpectResult::kExpectFailure; + + std::vector> cpu_execution_provider; + cpu_execution_provider.push_back(DefaultCpuExecutionProvider()); + test.Run(expected_result, expected_failure_msg, {}, nullptr, &cpu_execution_provider); + + constexpr int min_cuda_architecture = 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); + if (enable_cuda || enable_rocm) { + std::vector> gpu_execution_provider; + if (enable_cuda) { + gpu_execution_provider.push_back(DefaultCudaExecutionProvider()); + } else if (enable_rocm) { + gpu_execution_provider.push_back(DefaultRocmExecutionProvider()); + } + + if (gpu_execution_provider.size() > 0) { + test.Run(expected_result, expected_failure_msg, {}, nullptr, &gpu_execution_provider); + } + } +} + TEST(LayerNormTest, BERTLayerNorm) { OpTester tester("LayerNormalization", 17 /*opset_version*/); tester.AddAttribute("axis", -1); @@ -210,6 +238,106 @@ 
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } +TEST(LayerNormTest, LayerNorm_Scale_Bias_NoBroadcast) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, {-1.0f, 2.0f, 3.0f, -4.0f, -10.264f, 8.6453f, 43.1561f, -0.641239f}); + test.AddInput("gamma", {2, 2, 2}, {-0.1f, 1.7f, -0.6953f, 5.1824f, -0.1f, 1.7f, -0.6953f, 5.1824f}); + test.AddInput("bias", {2, 2, 2}, {-2.0f, 0.3f, 0.0f, 0.0f, -2.0f, 0.3f, 0.0f, 0.0f}); + test.AddOutput("output", dims, {-1.9f, 2.0f, -0.6953f, -5.1824f, -1.9f, 2.0f, -0.6953f, -5.1824f}); + + test.SetOutputTolerance(0.0001f); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_NoBroadcast_Fp16) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 2, 2}; + test.AddInput("x", dims, ToFloat16({-1.0f, 2.0f, 3.0f, -4.0f, -10.264f, 8.6453f, 43.1561f, -0.641239f})); + test.AddInput("gamma", {2, 2, 2}, ToFloat16({-0.1f, 1.7f, -0.6953f, 5.1824f, -0.1f, 1.7f, -0.6953f, 5.1824f})); + test.AddInput("bias", {2, 2, 2}, ToFloat16({-2.0f, 0.3f, 0.0f, 0.0f, -2.0f, 0.3f, 0.0f, 0.0f})); + test.AddOutput("output", dims, ToFloat16({-1.9f, 2.0f, -0.6953f, -5.1824f, -1.9f, 2.0f, -0.6953f, -5.1824f})); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Broadcast_Dim0) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{4, 2, 2}; + test.AddInput("x", dims, {-1.0f, 2.0f, -10.264f, 8.6453f, 3.0f, -4.0f, 43.1561f, -0.641239f, -5.0f, 6.0f, -8.2164f, 0.11412f, 7.0f, 8.0f, 41.3156f, 3.0458f}); + test.AddInput("gamma", {1, 2, 2}, {-0.1f, 1.7f, -0.6953f, 5.1824f}); + test.AddInput("bias", {1, 2, 2}, {-2.0f, 0.3f, 0.0f, 0.0f}); + test.AddOutput("output", dims, {-1.9f, 2.0f, 0.6953f, 5.1824f, -2.1f, -1.4f, -0.6953f, 
-5.1824f, -1.9f, 2.0f, 0.6953f, 5.1824f, -1.9f, 2.0f, -0.6953f, -5.1824f}); + test.SetOutputTolerance(0.0001f); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Broadcast_Dim0_Fp16) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{4, 2, 2}; + test.AddInput("x", dims, ToFloat16({-1.0f, 2.0f, -10.264f, 8.6453f, 3.0f, -4.0f, 43.1561f, -0.641239f, -5.0f, 6.0f, -8.2164f, 0.11412f, 7.0f, 8.0f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {1, 2, 2}, ToFloat16({-0.1f, 1.7f, -0.6953f, 5.1824f})); + test.AddInput("bias", {1, 2, 2}, ToFloat16({-2.0f, 0.3f, 0.0f, 0.0f})); + test.AddOutput("output", dims, ToFloat16({-1.9f, 2.0f, 0.6953f, 5.1824f, -2.1f, -1.4f, -0.6953f, -5.1824f, -1.9f, 2.0f, 0.6953f, 5.1824f, -1.9f, 2.0f, -0.6953f, -5.1824f})); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Broadcast_Dim1) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 4, 2}; + test.AddInput("x", dims, {-1.0f, 2.0f, 3.0f, -4.0f, -5.0f, 6.0f, 7.0f, 8.0f, -10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f}); + test.AddInput("gamma", {2, 1, 2}, {-0.1f, 1.7f, -0.6953f, 5.1824f}); + test.AddInput("bias", {2, 1, 2}, {-2.0f, 0.3f, 0.0f, 0.0f}); + test.AddOutput("output", dims, {-1.9f, 2.0f, -2.1f, -1.4f, -1.9f, 2.0f, -1.9f, 2.0f, 0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f}); + test.SetOutputTolerance(0.0001f); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Broadcast_Dim1_Fp16) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{2, 4, 2}; + test.AddInput("x", dims, ToFloat16({-1.0f, 2.0f, 3.0f, -4.0f, -5.0f, 6.0f, 7.0f, 8.0f, -10.264f, 8.6453f, 43.1561f, -0.641239f, -8.2164f, 0.11412f, 41.3156f, 3.0458f})); + test.AddInput("gamma", {2, 1, 2}, ToFloat16({-0.1f, 1.7f, -0.6953f, 5.1824f})); + 
test.AddInput("bias", {2, 1, 2}, ToFloat16({-2.0f, 0.3f, 0.0f, 0.0f})); + test.AddOutput("output", dims, ToFloat16({-1.9f, 2.0f, -2.1f, -1.4f, -1.9f, 2.0f, -1.9f, 2.0f, 0.6953f, 5.1824f, -0.6953f, -5.1824f, 0.6953f, 5.1824f, -0.6953f, -5.1824f})); + + RunTestOnCpuAndCuda(test); +} + +TEST(LayerNormTest, LayerNorm_Scale_Bias_Broadcast_Fp16) { + auto run_test = [](bool is_initializer) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 2}; + test.AddInput("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f})); + test.AddInput("gamma", {1, 1, 2}, ToFloat16({-0.6953f, 5.1824f}), is_initializer); + test.AddInput("bias", {1, 1, 2}, ToFloat16({0.6435f, -0.3964f}), is_initializer); + test.AddOutput("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f})); + + RunTestOnCpuAndCuda(test); + }; + + run_test(false); + run_test(true); +} + TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { auto run_test = [](bool is_initializer) { OpTester test("LayerNormalization"); @@ -300,6 +428,21 @@ TEST(LayerNormTest, LayerNorm17_double) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider}); } +// Test normalize size shall be larger than 1. 
+TEST(LayerNormTest, LayerNorm_InvalidNormSize) { + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 3, 1}; + test.AddInput("x", dims, {1.2416f, 0.946123f, 13.1685f}); + test.AddInput("gamma", {1}, {-0.6953f}); + test.AddInput("bias", {1}, {0.6435f}); + test.AddAttribute("axis", 2); + test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f}); + + RunTestOnCpuAndCuda(test, kLayerNormInvalidSize); +} + TEST(LayerNormTest, LayerNorm_InvalidScaleBias) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); @@ -311,11 +454,10 @@ TEST(LayerNormTest, LayerNorm_InvalidScaleBias) { test.AddInput("bias", {2}, {0.6435f, -0.3964f}); test.AddAttribute("axis", 1); test.AddOutput("output", dims, {-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}); + // CPU and CUDA EPs have check for unexpected scale or bias sizes. Exclude other EPs with a LayerNormalization // implementation for which we don't control the check or error message. - test.Run(OpTester::ExpectResult::kExpectFailure, - "Size of X.shape()[axis:] == 6. Size of scale and bias (if provided) must match this", - {kDnnlExecutionProvider, kDmlExecutionProvider, kTensorrtExecutionProvider}); + RunTestOnCpuAndCuda(test, kLayerNormInputShapeMismatchError); } #if defined(USE_DNNL) diff --git a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc index f6158d8cbc..0fccc68c59 100644 --- a/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc +++ b/onnxruntime/test/onnx/microbenchmark/layer_normalization.cc @@ -114,9 +114,9 @@ static void BM_LayerNormalization(benchmark::State& state) { auto status = layer_norm_impl.ComputeWithoutContext(x_data, x_shape, scale_data, - static_cast(scale_shape.Size()), + scale_shape, bias_data, - static_cast(bias_shape.Size()), + bias_shape, Y_data, mean_data, inv_std_dev_data,