mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Broadcasting for SLN for CPU and CUDA (#16510)
### Description
Enhanced SkipLayerNorm by implementing broadcasting for both CPU and
CUDA
### Motivation and Context
The input and skip tensors no longer have to be the same size which
means that it can accept data where the skip shape can be the same size
as the input shape, have a shape of {1, sequence_length, hidden_size},
or {sequence_length, hidden_size}.
---------
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
This commit is contained in:
parent
3649376f09
commit
4e6ea730d6
12 changed files with 276 additions and 117 deletions
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
|
|
@ -40,4 +40,3 @@
|
|||
"-build/include_subdir",
|
||||
"-runtime/references"
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4651,7 +4651,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dt><tt>input</tt> : T</dt>
|
||||
<dd>3D input tensor with shape (batch_size, sequence_length, hidden_size)</dd>
|
||||
<dt><tt>skip</tt> : T</dt>
|
||||
<dd>3D skip tensor with shape (batch_size, sequence_length, hidden_size)</dd>
|
||||
<dd>3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)</dd>
|
||||
<dt><tt>gamma</tt> : T</dt>
|
||||
<dd>1D input tensor with shape (hidden_size)</dd>
|
||||
<dt><tt>beta</tt> (optional) : T</dt>
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/providers/common.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include "skip_layer_norm.h"
|
||||
#include "skip_layer_norm_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -45,51 +46,15 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
|
|||
|
||||
const auto& input_dims = input->Shape().GetDims();
|
||||
size_t input_dims_size = input_dims.size();
|
||||
if (input_dims_size != 3 && input_dims_size != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 3 or 2 dimensions, got ", input_dims_size);
|
||||
}
|
||||
|
||||
int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);
|
||||
|
||||
if (input->Shape() != skip->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have same shape as input");
|
||||
}
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of gamma and input does not match");
|
||||
}
|
||||
|
||||
if (nullptr != beta) {
|
||||
const auto& beta_dims = beta->Shape().GetDims();
|
||||
if (beta_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of beta and input does not match");
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr != bias) {
|
||||
const auto& bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"bias is expected to have 1 dimension, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of bias and input does not match");
|
||||
}
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
|
||||
skip,
|
||||
gamma,
|
||||
beta,
|
||||
bias,
|
||||
hidden_size,
|
||||
input_dims_size));
|
||||
|
||||
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
|
||||
|
||||
|
|
@ -105,13 +70,15 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
|
|||
// of the input and skip tensors
|
||||
T* skip_input_bias_add_output_data = skip_input_bias_add_output != nullptr ? skip_input_bias_add_output->MutableData<T>() : nullptr;
|
||||
|
||||
const auto& skip_size = skip->Shape().Size();
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
auto offset = task_idx * hidden_size;
|
||||
|
||||
const T* p_input = input_data + offset;
|
||||
const T* p_skip = skip_data + offset;
|
||||
const T* p_skip = skip_data + (offset % skip_size);
|
||||
T* p_output = output_data + offset;
|
||||
T* p_skip_input_bias_add_output_data = skip_input_bias_add_output_data != nullptr ? skip_input_bias_add_output_data + offset : nullptr;
|
||||
|
||||
|
|
|
|||
84
onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h
Normal file
84
onnxruntime/contrib_ops/cpu/skip_layer_norm_helper.h
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace skip_layer_norm_helper {
|
||||
|
||||
template <typename T>
|
||||
Status CheckInputs(const T* input,
|
||||
const T* skip,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
const T* bias,
|
||||
int hidden_size_check,
|
||||
size_t input_dims_size_check) {
|
||||
const auto& input_dims_check = input->Shape().GetDims();
|
||||
const auto& skip_dims_check = skip->Shape().GetDims();
|
||||
size_t skip_dims_size_check = skip_dims_check.size();
|
||||
|
||||
if (skip_dims_size_check != 3 && skip_dims_size_check != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have 3 or 2 dimensions, got ", skip_dims_size_check);
|
||||
}
|
||||
|
||||
if ((input->Shape() != skip->Shape()) && ((skip_dims_check[0] != 1 || skip_dims_size_check != 2) && input_dims_size_check != 3)) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have same shape as input or, a batch size of 1 or no batch size when input has 3 dimensions");
|
||||
}
|
||||
|
||||
if (input_dims_size_check != 3 && input_dims_size_check != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 3 or 2 dimensions, got ", input_dims_size_check);
|
||||
}
|
||||
|
||||
if (skip_dims_check[skip_dims_size_check - 1] != input_dims_check[input_dims_size_check - 1] || skip_dims_check[skip_dims_size_check - 2] != input_dims_check[input_dims_size_check - 2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"last two dimensions of skip needs to be same as input");
|
||||
}
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != hidden_size_check) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of gamma and input does not match");
|
||||
}
|
||||
|
||||
if (nullptr != beta) {
|
||||
const auto& beta_dims = beta->Shape().GetDims();
|
||||
if (beta_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != hidden_size_check) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of beta and input does not match");
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr != bias) {
|
||||
const auto& bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"bias is expected to have 1 dimension, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != hidden_size_check) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of bias and input does not match");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace skip_layer_norm_helper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "core/providers/cuda/nn/layer_norm_impl.h"
|
||||
#include "skip_layer_norm.h"
|
||||
#include "skip_layer_norm_impl.h"
|
||||
#include "contrib_ops/cpu/skip_layer_norm_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -60,59 +61,23 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
|
|||
// of the input and skip tensors
|
||||
Tensor* skip_input_bias_add_output = ctx->Output(3, input->Shape());
|
||||
|
||||
if (input->Shape() != skip->Shape()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"skip is expected to have same shape as input");
|
||||
}
|
||||
|
||||
if (input->Shape().Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& input_dims = input->Shape().GetDims();
|
||||
size_t input_dims_size = input_dims.size();
|
||||
if (input_dims_size != 3 && input_dims_size != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input is expected to have 3 or 2 dimensions, got ", input_dims_size);
|
||||
}
|
||||
const auto& skip_dims = skip->Shape().GetDims();
|
||||
size_t skip_dims_size = skip_dims.size();
|
||||
|
||||
int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);
|
||||
|
||||
const auto& gamma_dims = gamma->Shape().GetDims();
|
||||
if (gamma_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
|
||||
}
|
||||
if (gamma_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of gamma and input does not match");
|
||||
}
|
||||
ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
|
||||
skip,
|
||||
gamma,
|
||||
beta,
|
||||
bias,
|
||||
hidden_size,
|
||||
input_dims_size));
|
||||
|
||||
if (!Simplified) {
|
||||
if (nullptr != beta) {
|
||||
const auto& beta_dims = beta->Shape().GetDims();
|
||||
if (beta_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"beta is expected to have 1 dimension, got ", beta_dims.size());
|
||||
}
|
||||
if (beta_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of beta and input does not match");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nullptr != bias) {
|
||||
const auto& bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"bias is expected to have 1 dimension, got ", bias_dims.size());
|
||||
}
|
||||
if (bias_dims[0] != hidden_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Last dimension of bias and input does not match");
|
||||
}
|
||||
}
|
||||
const bool skip_broadcasted = (skip_dims[0] == 1 || skip_dims_size == 2) ? true : false;
|
||||
const int skip_size = static_cast<int>(skip_dims[skip_dims_size - 1] * skip_dims[skip_dims_size - 2]);
|
||||
|
||||
if (strict_) {
|
||||
int row_count = gsl::narrow<int>(input->Shape().SizeToDimension(input_dims_size - 1));
|
||||
|
|
@ -131,7 +96,9 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
|
|||
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
|
||||
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
|
||||
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr);
|
||||
skip_input_bias_add_output != nullptr ? reinterpret_cast<CudaT*>(skip_input_bias_add_output->MutableData<T>()) : nullptr,
|
||||
skip_broadcasted,
|
||||
skip_size);
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
return Status::OK();
|
||||
|
|
@ -153,7 +120,9 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
|
|||
epsilon_,
|
||||
hidden_size,
|
||||
static_cast<int>(element_count),
|
||||
element_size);
|
||||
element_size,
|
||||
skip_broadcasted,
|
||||
skip_size);
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ template <typename T, unsigned TPB, bool Simplified>
|
|||
__global__ void SkipLayerNormKernel(
|
||||
const int ld, const T* input, const T* skip,
|
||||
const T* beta, const T* gamma, const T* bias,
|
||||
const T epsilon, T* output, T* skip_input_bias_add_output) {
|
||||
const T epsilon, T* output, T* skip_input_bias_add_output, const bool skip_broadcasted, int skip_size) {
|
||||
const T reverse_ld = T(1.f / ld);
|
||||
const int offset = blockIdx.x * ld;
|
||||
|
||||
|
|
@ -90,10 +90,13 @@ __global__ void SkipLayerNormKernel(
|
|||
// reduce x and x^2
|
||||
cub::KeyValuePair<T, T> thread_data(0, 0);
|
||||
|
||||
|
||||
for (int i = threadIdx.x; i < ld; i += TPB) {
|
||||
const int idx = offset + i;
|
||||
|
||||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
|
||||
const T skip_data = skip_broadcasted ? skip[idx % skip_size] : skip[idx];
|
||||
const T val = (bias == nullptr) ? input[idx] + skip_data : input[idx] + skip_data + bias[i];
|
||||
|
||||
const T rldval = reverse_ld * val;
|
||||
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
|
||||
|
||||
|
|
@ -115,7 +118,7 @@ template <typename T, unsigned TPB, int ILP, bool Simplified>
|
|||
__global__ void SkipLayerNormKernelSmall(
|
||||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
|
||||
const T* bias, const T epsilon, T* output, T* skip_input_bias_add_output,
|
||||
bool hasBias, bool hasSkipInputBiasAdditionOutput) {
|
||||
bool hasBias, bool hasSkipInputBiasAdditionOutput, const bool skip_broadcasted, const int skip_size) {
|
||||
const T rld = T(1.f / ld);
|
||||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld
|
||||
|
||||
|
|
@ -127,7 +130,11 @@ __global__ void SkipLayerNormKernelSmall(
|
|||
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);
|
||||
|
||||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
|
||||
if (skip_broadcasted){
|
||||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx % skip_size]);
|
||||
}else{
|
||||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);
|
||||
}
|
||||
|
||||
if (hasBias) {
|
||||
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
|
||||
|
|
@ -170,8 +177,13 @@ template <typename T, bool Simplified>
|
|||
Status LaunchSkipLayerNormKernel(
|
||||
cudaStream_t stream, T* output, T* skip_input_bias_add_output, const T* input, const T* skip, const T* gamma,
|
||||
const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
|
||||
size_t element_size) {
|
||||
size_t element_size, const bool skip_broadcasted, const int skip_size) {
|
||||
// this must be true because n is the total size of the tensor
|
||||
|
||||
if (element_count == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
assert(element_count % ld == 0);
|
||||
bool hasBias = (bias == nullptr) ? false : true;
|
||||
bool hasSkipInputBiasAdditionOutput = (skip_input_bias_add_output == nullptr) ? false : true;
|
||||
|
|
@ -187,10 +199,10 @@ Status LaunchSkipLayerNormKernel(
|
|||
#define LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL(num_unroll) \
|
||||
SkipLayerNormKernelSmall<T, block_size, num_unroll, Simplified> \
|
||||
<<<grid_size, block_size, 0, stream>>>(ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, \
|
||||
skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput)
|
||||
skip_input_bias_add_output, hasBias, hasSkipInputBiasAdditionOutput, skip_broadcasted, skip_size)
|
||||
#define LAUNCH_SKIP_LAYER_NORM_KERNEL() \
|
||||
SkipLayerNormKernel<T, kMaxBlockSize, Simplified><<<grid_size, kMaxBlockSize, 0, stream>>>( \
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, skip_input_bias_add_output)
|
||||
ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, skip_input_bias_add_output, skip_broadcasted, skip_size)
|
||||
#define CASE_NEXT_SIZE(next_size_value) \
|
||||
case next_size_value: { \
|
||||
if (flag_vec4) { \
|
||||
|
|
@ -219,7 +231,6 @@ Status LaunchSkipLayerNormKernel(
|
|||
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL
|
||||
#undef LAUNCH_SKIP_LAYER_NORM_KERNEL_SMALL
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
|
@ -229,7 +240,8 @@ Status LaunchSkipLayerNormKernel(
|
|||
const T* input, const T* skip, const T* gamma, \
|
||||
const T* beta, const T* bias, float epsilon, \
|
||||
const int ld, const int element_count, \
|
||||
size_t element_size);
|
||||
size_t element_size, const bool skip_broadcasted, \
|
||||
const int skip_size);
|
||||
|
||||
SKIPLAYERNORM_IMPL(float, true);
|
||||
SKIPLAYERNORM_IMPL(float, false);
|
||||
|
|
|
|||
|
|
@ -21,7 +21,9 @@ Status LaunchSkipLayerNormKernel(
|
|||
float epsilon, // Layer normalization epsilon
|
||||
int hidden_size, // hidden size, it is the leading dimension (ld)
|
||||
int element_count, // number of elements in input tensor
|
||||
size_t element_size);
|
||||
size_t element_size,
|
||||
const bool skip_broadcasted, // determines if broadcasting should be implemented
|
||||
const int skip_size); // determines size of the skip tensor
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -1097,7 +1097,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
.SetDoc("Skip and Layer Normalization Fusion")
|
||||
.Attr("epsilon", "The epsilon value to use to avoid division by zero.", AttributeProto::FLOAT, kDefaultSkipLayerNormEpsilon)
|
||||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
|
||||
.Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T")
|
||||
.Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)", "T")
|
||||
.Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T")
|
||||
.Input(3, "beta", "1D skip tensor with shape (hidden_size", "T", OpSchema::Optional)
|
||||
.Input(4, "bias", "1D bias tensor with shape (hidden_size", "T", OpSchema::Optional)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
|
||||
// LayerNorm supports limited data types.
|
||||
static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
|
||||
static constexpr std::array supported_data_types{
|
||||
"tensor(float16)", "tensor(float)", "tensor(bfloat16)"};
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
|
|
|
|||
|
|
@ -338,7 +338,9 @@ __global__ void cuApplyLayerNorm(
|
|||
const V* __restrict__ beta,
|
||||
const T* __restrict__ skip,
|
||||
const T* __restrict__ bias,
|
||||
T* __restrict__ skip_input_bias_add_output) {
|
||||
T* __restrict__ skip_input_bias_add_output,
|
||||
const bool skip_broadcasted,
|
||||
const int skip_size) {
|
||||
// Assumptions:
|
||||
// 1) blockDim.x == GPU_WARP_SIZE
|
||||
// 2) Tensors are contiguous
|
||||
|
|
@ -358,11 +360,16 @@ __global__ void cuApplyLayerNorm(
|
|||
for (int i = thrx; i < n2; i += numx) {
|
||||
U curr = static_cast<U>(lvals[i]);
|
||||
|
||||
|
||||
|
||||
if (bias != NULL) {
|
||||
curr += static_cast<U>(bias[i]);
|
||||
}
|
||||
|
||||
if (skip_vals != NULL) {
|
||||
if (skip_vals != NULL && skip_broadcasted) {
|
||||
int skip_i = i % skip_size;
|
||||
curr += static_cast<U>(skip_vals[skip_i]); //Calculates index for the second dimension of the skip tensor
|
||||
}else if (skip_vals != NULL){
|
||||
curr += static_cast<U>(skip_vals[i]);
|
||||
}
|
||||
|
||||
|
|
@ -411,7 +418,9 @@ void HostApplyLayerNorm(
|
|||
const V* beta,
|
||||
const T* skip,
|
||||
const T* bias,
|
||||
T* skip_input_bias_add_output) {
|
||||
T* skip_input_bias_add_output,
|
||||
const bool skip_broadcasted,
|
||||
const int skip_size) {
|
||||
const int maxGridY = prop.maxGridSize[1];
|
||||
const int warp_size = prop.warpSize;
|
||||
ORT_ENFORCE(warp_size == GPU_WARP_SIZE_HOST);
|
||||
|
|
@ -443,14 +452,17 @@ void HostApplyLayerNorm(
|
|||
n1, n2,
|
||||
U(epsilon),
|
||||
gamma, beta,
|
||||
skip, bias, skip_input_bias_add_output);
|
||||
skip, bias, skip_input_bias_add_output,
|
||||
skip_broadcasted,
|
||||
skip_size);
|
||||
}
|
||||
|
||||
#define LAYERNORM_LINEAR_IMPL(T, U, V, simplified) \
|
||||
template void HostApplyLayerNorm<T, U, V, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, V* output, \
|
||||
U* mean, U* inv_std_dev, const T* input, int n1, int n2, \
|
||||
double epsilon, const V* gamma, const V* beta, const T* skip, \
|
||||
const T* bias, T* skip_input_bias_add_output);
|
||||
const T* bias, T* skip_input_bias_add_output, const bool skip_broadcasted, \
|
||||
const int skip_size);
|
||||
|
||||
LAYERNORM_LINEAR_IMPL(float, float, float, true)
|
||||
LAYERNORM_LINEAR_IMPL(half, float, half, true)
|
||||
|
|
|
|||
|
|
@ -43,7 +43,9 @@ void HostApplyLayerNorm(
|
|||
const V* beta,
|
||||
const T* skip = nullptr,
|
||||
const T* bias = nullptr,
|
||||
T* skip_input_bias_add_output = nullptr);
|
||||
T* skip_input_bias_add_output = nullptr,
|
||||
const bool skip_broadcasted = false,
|
||||
const int skip_size = 0);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -27,18 +27,31 @@ static void RunTest(
|
|||
bool no_beta = false,
|
||||
bool simplified = false,
|
||||
bool use_token_count = false,
|
||||
bool strict = false) {
|
||||
bool strict = false,
|
||||
bool broadcast_skip = false,
|
||||
bool no_batch_size = false) {
|
||||
// Input and output shapes
|
||||
// Input 0 - input: (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
|
||||
// Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
|
||||
// Input 1 - skip : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size) or (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
|
||||
// Input 2 - gamma: (hidden_size)
|
||||
// Input 3 - beta : (hidden_size)
|
||||
// Output : (batch_size, sequence_length, hidden_size) or (batch_size * sequence_length, hidden_size)
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> skip_dims = input_dims;
|
||||
|
||||
if (use_token_count) {
|
||||
input_dims = {batch_size * sequence_length, hidden_size};
|
||||
skip_dims = input_dims;
|
||||
}
|
||||
std::vector<int64_t> skip_dims = input_dims;
|
||||
|
||||
if (broadcast_skip) {
|
||||
skip_dims = {1, sequence_length, hidden_size};
|
||||
}
|
||||
|
||||
if (no_batch_size) {
|
||||
skip_dims = {sequence_length, hidden_size};
|
||||
}
|
||||
|
||||
std::vector<int64_t> gamma_dims = {hidden_size};
|
||||
std::vector<int64_t> beta_dims = gamma_dims;
|
||||
std::vector<int64_t> bias_dims = gamma_dims;
|
||||
|
|
@ -48,6 +61,8 @@ static void RunTest(
|
|||
|
||||
auto rocm_ep = DefaultRocmExecutionProvider();
|
||||
auto dml_ep = DefaultDmlExecutionProvider();
|
||||
auto cpu_ep = DefaultCpuExecutionProvider();
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (!use_float16) {
|
||||
OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
|
||||
test.AddInput<float>("input", input_dims, input_data);
|
||||
|
|
@ -77,7 +92,10 @@ static void RunTest(
|
|||
skip_input_bias_add_output_data);
|
||||
}
|
||||
|
||||
test.Run();
|
||||
if (cpu_ep != nullptr) {
|
||||
execution_providers.push_back(DefaultCpuExecutionProvider());
|
||||
}
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
|
||||
dml_ep != nullptr ||
|
||||
rocm_ep != nullptr) {
|
||||
|
|
@ -109,7 +127,6 @@ static void RunTest(
|
|||
ToFloat16(skip_input_bias_add_output_data));
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
if (dml_ep != nullptr) {
|
||||
execution_providers.push_back(DefaultDmlExecutionProvider());
|
||||
} else if (rocm_ep != nullptr) {
|
||||
|
|
@ -718,6 +735,100 @@ TEST(SkipLayerNormTest, SkipSimplifiedLayerNormBatch1_Float16) {
|
|||
true,
|
||||
true);
|
||||
}
|
||||
|
||||
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_No_Batch_Size) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> skip_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f};
|
||||
|
||||
std::vector<float> gamma_data = {
|
||||
0.3f, 0.2f, 4.0f, 2.2f};
|
||||
|
||||
std::vector<float> beta_data = {
|
||||
0.2f, 0.1f, 0.4f, 1.6f};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
|
||||
0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
|
||||
0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
|
||||
0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
|
||||
|
||||
RunTest(input_data,
|
||||
skip_data,
|
||||
gamma_data,
|
||||
beta_data,
|
||||
std::vector<float>(),
|
||||
output_data,
|
||||
{},
|
||||
epsilon_,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
hidden_size,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true);
|
||||
}
|
||||
|
||||
TEST(SkipLayerNormTest, SkipLayerNormBatch2_Skip_Broadcast_Batch_Size_1) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> skip_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f};
|
||||
|
||||
std::vector<float> gamma_data = {
|
||||
0.3f, 0.2f, 4.0f, 2.2f};
|
||||
|
||||
std::vector<float> beta_data = {
|
||||
0.2f, 0.1f, 0.4f, 1.6f};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
|
||||
0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
|
||||
0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
|
||||
0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
|
||||
|
||||
RunTest(input_data,
|
||||
skip_data,
|
||||
gamma_data,
|
||||
beta_data,
|
||||
std::vector<float>(),
|
||||
output_data,
|
||||
{},
|
||||
epsilon_,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
hidden_size,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue