Add EmbedLayerNormalization and SkipLayerNormalization ops for bert optimization (#2012)

* Add Embed Layer Normalization and Skip Layer Normalization ops for bert optimization.

* add float16 test for skiplayernorm

* Add test for EmbedLayerNormalization op

* fix cpu build error

* fix build warning

* update HasCudaEnvironment function

* handle cuda error
This commit is contained in:
Tianlei Wu 2019-10-07 17:29:43 -07:00 committed by GitHub
parent 8f7657fa32
commit b2c1937523
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 1341 additions and 2 deletions

View file

@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx/defs/tensor_proto_util.h"
#include "embed_layer_norm.h"
#include "embed_layer_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Registers EmbedLayerNorm<T> as the CUDA-provider kernel for the EmbedLayerNormalization
// operator (kMSDomain, version 1), with type constraint "T" bound to tensors of element type T.
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
EmbedLayerNormalization, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
EmbedLayerNorm<T>);
// Instantiate the registration for the two element types handled by this kernel.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
// Constructor: the kernel keeps no state of its own, so the kernel info is simply
// forwarded to the CudaKernel base class.
template <typename T>
EmbedLayerNorm<T>::EmbedLayerNorm(const OpKernelInfo& op_kernel_info)
    : CudaKernel(op_kernel_info) {}
// Validates the operator inputs, allocates both outputs and launches the CUDA implementation.
//
// Inputs (by index): 0 input_ids, 1 segment_ids, 2 mask (all int32 and of identical 2-D shape),
// 3 word_embedding, 4 position_embedding, 5 segment_embedding (2-D, sharing dimension 1),
// 6 gamma, 7 beta (1-D, of length hidden_size).
// Outputs: 0 output with shape (input_dims[0], input_dims[1], hidden_size),
//          1 mask_index with shape (input_dims[0]).
//
// Returns INVALID_ARGUMENT when a shape check fails and FAIL when the kernel launch reports an
// error. The device code indexes every embedding table, gamma and beta by hidden_size, so all of
// them are checked against it here to avoid out-of-bounds device reads.
template <typename T>
Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input_ids = context->Input<Tensor>(0);
  const Tensor* segment_ids = context->Input<Tensor>(1);
  const Tensor* mask = context->Input<Tensor>(2);
  const Tensor* word_embedding = context->Input<Tensor>(3);
  const Tensor* position_embedding = context->Input<Tensor>(4);
  const Tensor* segment_embedding = context->Input<Tensor>(5);
  const Tensor* gamma = context->Input<Tensor>(6);
  const Tensor* beta = context->Input<Tensor>(7);

  if (input_ids->Shape() != segment_ids->Shape() || input_ids->Shape() != mask->Shape()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 0, 1 and 2 shall have same shape");
  }

  const auto& input_dims = input_ids->Shape().GetDims();
  if (input_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input_ids is expected to have 2 dimensions, got ", input_dims.size());
  }

  const auto& word_embedding_dims = word_embedding->Shape().GetDims();
  if (word_embedding_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "word_embedding is expected to have 2 dimensions, got ", word_embedding_dims.size());
  }

  const auto& position_embedding_dims = position_embedding->Shape().GetDims();
  if (position_embedding_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "position_embedding is expected to have 2 dimensions, got ", position_embedding_dims.size());
  }

  const auto& segment_embedding_dims = segment_embedding->Shape().GetDims();
  if (segment_embedding_dims.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "segment_embedding is expected to have 2 dimensions, got ", segment_embedding_dims.size());
  }

  // hidden_size is taken from the word embedding; every other table must agree with it.
  const int64_t hidden_size = word_embedding_dims[1];

  if (position_embedding_dims[1] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "word_embedding and position_embedding shall have same dimension 1");
  }

  if (segment_embedding_dims[1] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "word_embedding and segment_embedding shall have same dimension 1");
  }

  const auto& gamma_dims = gamma->Shape().GetDims();
  if (gamma_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
  }
  if (gamma_dims[0] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have size of ", hidden_size, ", got ", gamma_dims[0]);
  }

  const auto& beta_dims = beta->Shape().GetDims();
  if (beta_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", beta_dims.size());
  }
  if (beta_dims[0] != hidden_size) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have size of ", hidden_size, ", got ", beta_dims[0]);
  }

  // Output 0: (batch, sequence, hidden).
  std::vector<int64_t> out_dims{input_dims[0], input_dims[1], hidden_size};
  TensorShape output_shape(out_dims);
  Tensor* output = context->Output(0, output_shape);

  // Output 1: one mask index per batch row.
  std::vector<int64_t> mask_index_dims{input_dims[0]};
  TensorShape mask_index_shape(mask_index_dims);
  Tensor* mask_index = context->Output(1, mask_index_shape);

  const int batch_size = static_cast<int>(input_dims[0]);
  const int sequence_length = static_cast<int>(input_dims[1]);
  const size_t element_size = sizeof(T);

  if (!LaunchEmbedLayerNormKernel(
          output->template MutableData<T>(),
          mask_index->template MutableData<int32_t>(),
          input_ids->template Data<int32_t>(),
          segment_ids->template Data<int32_t>(),
          mask->template Data<int32_t>(),
          gamma->template Data<T>(),
          beta->template Data<T>(),
          word_embedding->template Data<T>(),
          position_embedding->template Data<T>(),
          segment_embedding->template Data<T>(),
          static_cast<int>(hidden_size),
          batch_size,
          sequence_length,
          element_size)) {
    // Get last error to reset it to cudaSuccess.
    CUDA_CALL(cudaGetLastError());
    return Status(common::ONNXRUNTIME, common::FAIL);
  }

  return Status::OK();
}
} //namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
// CUDA kernel class registered for the EmbedLayerNormalization contrib operator.
// T is the floating-point element type of the embeddings and of the output.
template <typename T>
class EmbedLayerNorm final : public CudaKernel {
 public:
  // explicit: an OpKernelInfo must never convert implicitly into a kernel object.
  explicit EmbedLayerNorm(const OpKernelInfo& op_kernel_info);

  // Validates inputs, allocates outputs and runs the kernel; see the definition for details.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,206 @@
/*
The implementation of this file is based on embLayerNorm plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "layer_norm.cuh"
#include "embed_layer_norm_impl.h"
using namespace onnxruntime::cuda;
using namespace cub;
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Finds, for each batch row, the position of the first zero in the mask (or sequence_length
// when the row contains no zero). Variant for rows that fit in a single block: every thread
// inspects at most one element.
template <unsigned TPB>
__global__ void MaskIndexKernelSmall(int sequence_length, const int* mask, int* mask_index) {
  using BlockReduce = cub::BlockReduce<int, TPB>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;

  // Block b handles batch row b; consecutive rows are sequence_length elements apart.
  const int row_start = blockIdx.x * sequence_length;
  const int element = row_start + threadIdx.x;

  // A thread that sees no zero contributes sequence_length, the neutral value for the min.
  int candidate(sequence_length);
  if (threadIdx.x < sequence_length) {
    if (mask[element] == 0) {
      candidate = threadIdx.x;
    }
  }

  cub::Min take_min;
  const auto first_zero = BlockReduce(reduce_storage).Reduce(candidate, take_min);

  // Only thread 0 writes the reduced result for this row.
  if (threadIdx.x == 0) {
    mask_index[blockIdx.x] = first_zero;
  }
}
// Finds, for each batch row, the position of the first zero in the mask (or sequence_length
// when the row contains no zero). General variant: each thread walks the row in steps of TPB.
template <unsigned TPB>
__global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_index) {
  using BlockReduce = cub::BlockReduce<int, TPB>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;

  // Block b handles batch row b; consecutive rows are sequence_length elements apart.
  const int row_start = blockIdx.x * sequence_length;

  cub::Min take_min;

  // Smallest zero position seen by this thread; sequence_length means "none".
  int candidate(sequence_length);
  for (int pos = threadIdx.x; pos < sequence_length; pos += TPB) {
    if (mask[row_start + pos] == 0) {
      candidate = take_min(candidate, pos);
    }
  }

  const auto first_zero = BlockReduce(reduce_storage).Reduce(candidate, take_min);

  // Only thread 0 writes the reduced result for this row.
  if (threadIdx.x == 0) {
    mask_index[blockIdx.x] = first_zero;
  }
}
// Launches one block per batch row to compute mask_index[b], the position of the first zero
// in row b of the mask. The result has batch_size entries; it is meaningful as a sequence
// length when the valid region is contiguous and starts at the beginning of the row.
// Returns the outcome of CUDA_CALL on the launch status.
inline bool ComputeMaskIndex(cudaStream_t stream, const int sequence_length, const int batch_size, const int* mask, int* mask_index) {
  // Rows of up to 128 elements, or of exactly 384, are handled by the one-element-per-thread kernel.
  const bool single_pass = (sequence_length <= 128) || (sequence_length == 384);

  if (!single_pass) {
    MaskIndexKernel<256><<<batch_size, 256, 0, stream>>>(sequence_length, mask, mask_index);
  } else if (sequence_length <= 32) {
    MaskIndexKernelSmall<32><<<batch_size, 32, 0, stream>>>(sequence_length, mask, mask_index);
  } else if (sequence_length <= 128) {
    MaskIndexKernelSmall<128><<<batch_size, 128, 0, stream>>>(sequence_length, mask, mask_index);
  } else {
    MaskIndexKernelSmall<384><<<batch_size, 384, 0, stream>>>(sequence_length, mask, mask_index);
  }

  return CUDA_CALL(cudaPeekAtLastError());
}
// Kernel: one thread block per (batch, position) token. Adds the word, segment and position
// embedding rows of that token, stores the sum in `output`, then normalizes it in place via
// LayerNorm<T, TPB>.
// The indexing below expects gridDim.x = sequence_length and gridDim.y = batch_size.
// Note that `beta` comes before `gamma` in this parameter list.
// NOTE(review): word_id and segment_id are used as row indices without any range check, and
// blockIdx.x is used as the row of position_embedding; callers must guarantee the tables are
// large enough.
template <typename T, unsigned TPB>
__global__ void EmbedLayerNormKernel(
int hidden_size, const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
const T* word_embedding, const T* position_embedding, const T* segment_embedding,
T* output) {
KeyValuePairSum pair_sum;
// 1. lookup word and segment of the block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
// gridDim.x = sequence_length
// gridDim.y = batch_size
__shared__ int word_id;
__shared__ int segment_id;
// 1 / hidden_size, applied per element so the reduced sums are already means.
const T rld = T(1.f) / T(hidden_size);
// Flat index of this token: batch * sequence_length + position.
const int sequence_position = blockIdx.y * gridDim.x + blockIdx.x;
// Thread 0 loads the ids into shared memory; the barrier makes them visible to the whole block.
if (threadIdx.x == 0) {
word_id = input_ids[sequence_position];
segment_id = segment_ids[sequence_position];
}
__syncthreads();
// 2. load pos/segment/word embeddings and add them together
// offset into embeddings is given by word_id * hidden_size
const int position_offset = blockIdx.x * hidden_size;
const int word_offset = word_id * hidden_size;
const int segment_offset = segment_id * hidden_size;
// the output offset is given by b * (sequence_length * hidden_size) + s * hidden_size
const int output_offset = sequence_position * hidden_size;
// Per-thread partial sums: key accumulates val / hidden_size, value accumulates val^2 / hidden_size.
cub::KeyValuePair<T, T> thread_data(0, 0);
for (int it = threadIdx.x; it < hidden_size; it += TPB) {
const T w(word_embedding[word_offset + it]);
const T t(segment_embedding[segment_offset + it]);
const T p(position_embedding[position_offset + it]);
const T val = w + t + p;
// Store the un-normalized sum; LayerNorm reads it back from `output`.
output[output_offset + it] = val;
const T rldval = rld * val;
thread_data = pair_sum(thread_data, cub::KeyValuePair<T, T>(rldval, rldval * val));
}
// 3. layer norm on the sum
LayerNorm<T, TPB>(thread_data, hidden_size, output_offset, beta, gamma, output);
}
// Launches EmbedLayerNormKernel with one block of 256 threads per (position, batch) pair.
// Returns the outcome of CUDA_CALL on the launch status.
template <typename T>
bool EmbedSkipLayerNorm(
    cudaStream_t stream, int hidden_size, int batch_size, int sequence_length,
    const int* input_ids, const int* segment_ids, const T* beta, const T* gamma,
    const T* word_embedding, const T* position_embedding, const T* segment_embedding,
    T* output) {
  constexpr int kThreadsPerBlock = 256;

  // x indexes the position within the sequence, y indexes the batch row.
  const dim3 blocks(sequence_length, batch_size, 1);
  const dim3 threads(kThreadsPerBlock, 1, 1);

  EmbedLayerNormKernel<T, kThreadsPerBlock><<<blocks, threads, 0, stream>>>(
      hidden_size, input_ids, segment_ids, beta, gamma,
      word_embedding, position_embedding, segment_embedding, output);

  return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchEmbedLayerNormKernel(
void* output,
void* mask_index,
const int* input_ids,
const int* segment_ids,
const int* input_mask,
const void* gamma,
const void* beta,
const void* word_embedding,
const void* position_embedding,
const void* segment_embedding,
const int hidden_size,
int batch_size,
int sequence_length,
const size_t element_size) {
const cudaStream_t stream = nullptr; // default stream
if (!ComputeMaskIndex(stream, sequence_length, hidden_size, input_mask, static_cast<int*>(mask_index))) {
return false;
}
if (element_size == 2) {
return EmbedSkipLayerNorm<half>(
stream, hidden_size, batch_size, sequence_length, input_ids, segment_ids,
reinterpret_cast<const half*>(beta), reinterpret_cast<const half*>(gamma),
reinterpret_cast<const half*>(word_embedding), reinterpret_cast<const half*>(position_embedding), reinterpret_cast<const half*>(segment_embedding),
reinterpret_cast<half*>(output));
} else {
return EmbedSkipLayerNorm<float>(
stream, hidden_size, batch_size, sequence_length, input_ids, segment_ids,
reinterpret_cast<const float*>(beta), reinterpret_cast<const float*>(gamma),
reinterpret_cast<const float*>(word_embedding), reinterpret_cast<const float*>(position_embedding), reinterpret_cast<const float*>(segment_embedding),
reinterpret_cast<float*>(output));
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Host-side launcher for the EmbedLayerNormalization CUDA kernels.
// The void* buffers hold half or float elements, selected by element_size.
// Returns false when a kernel launch reports a CUDA error, true otherwise.
bool LaunchEmbedLayerNormKernel(void* output, // output tensor
void* mask_index, // output mask index
const int* input_ids, // input word IDs
const int* segment_ids, // input segment IDs
const int* input_mask, // input mask
const void* gamma, // weight for layer normalization
const void* beta, // bias for layer normalization
const void* word_embedding, // weights for word embeddings
const void* position_embedding, // weights for position embeddings
const void* segment_embedding, // weights for segment (like sentence) embeddings
const int hidden_size, // hidden size (that is head_size * num_heads)
int batch_size, // batch size
int sequence_length, // sequence length
const size_t element_size); // size of element in output tensor. 2 for half, 4 for float.
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,138 @@
/*
The implementation of this file is based on bert plugins in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cub/cub.cuh>
using namespace onnxruntime::cuda;
using namespace cub;
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Reciprocal square root on the device. Only the float and half specializations below are defined.
template <typename T>
__device__ inline T Rsqrt(const T& x);
// float: forwards to rsqrtf.
template <>
__device__ inline float Rsqrt(const float& x) {
return rsqrtf(x);
}
// half: uses hrsqrt when __CUDA_ARCH__ >= 530 (or when __CUDA_ARCH__ is not defined);
// otherwise computes in float with rsqrtf and converts the result back to half.
template <>
__device__ inline half Rsqrt(const half& x) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return hrsqrt(x);
#else
return half(rsqrtf(float(x)));
#endif
}
// Adds two half2 values lane by lane.
// Uses __hadd2 when __CUDA_ARCH__ >= 530 (or when __CUDA_ARCH__ is not defined); otherwise
// adds the two lanes separately with __hadd and repacks them.
__device__ inline half2 AddHalf2(const half2 a, const half2 b) {
#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
return __hadd2(a, b);
#else
return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y));
#endif
}
// Reduction functor that adds two cub::KeyValuePair objects component-wise
// (key with key, value with value). Overloads exist for float, half and half2 pairs.
struct KeyValuePairSum {
  // float pairs: plain addition of each component.
  __device__ inline cub::KeyValuePair<float, float> operator()(const cub::KeyValuePair<float, float>& a, const cub::KeyValuePair<float, float>& b) {
    const float key_sum = a.key + b.key;
    const float value_sum = a.value + b.value;
    return cub::KeyValuePair<float, float>(key_sum, value_sum);
  }

  // half pairs: pack (key, value) of each operand into a half2 so both components are
  // added by a single AddHalf2 call, then unpack.
  __device__ inline cub::KeyValuePair<half, half> operator()(const cub::KeyValuePair<half, half>& a, const cub::KeyValuePair<half, half>& b) {
    const half2 packed_sum = AddHalf2(__halves2half2(a.key, a.value), __halves2half2(b.key, b.value));
    return cub::KeyValuePair<half, half>(packed_sum.x, packed_sum.y);
  }

  // half2 pairs: keys and values are each added with AddHalf2.
  __device__ inline cub::KeyValuePair<half2, half2> operator()(const cub::KeyValuePair<half2, half2>& a, const cub::KeyValuePair<half2, half2>& b) {
    const half2 key_sum = AddHalf2(a.key, b.key);
    const half2 value_sum = AddHalf2(a.value, b.value);
    return cub::KeyValuePair<half2, half2>(key_sum, value_sum);
  }
};
// Block-wide layer normalization of one row that is already stored in `output`.
//   thread_data: this thread's partial sums; key = sum(x) / ld, value = sum(x * x) / ld.
//   ld:          number of elements in the row.
//   offset:      index of the row's first element in `output`.
//   beta, gamma: per-element bias and scale, indexed 0..ld-1.
//   output:      read for the un-normalized values and overwritten with the result.
// Must be called by every thread of the block, since it contains a block reduction and a barrier.
// NOTE(review): no epsilon is added before Rsqrt, so a row with zero variance yields a
// non-finite rsigma — confirm whether callers rely on this.
template <typename T, int TPB>
__device__ inline void LayerNorm(
const cub::KeyValuePair<T, T>& thread_data, const int ld, const int offset, const T* beta, const T* gamma, T* output) {
// Assuming thread_data is already divided by ld
using BlockReduce = cub::BlockReduce<cub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
// Thread 0 turns the reduced sums into mean and 1/std.dev. (variance = E[x^2] - mean^2)
// and publishes them through shared memory.
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu);
}
// Make mu and rsigma visible to all threads before they are used.
__syncthreads();
// Each thread normalizes the elements i, i + TPB, i + 2 * TPB, ... of the row.
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(gamma[i]);
const T b(beta[i]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
// Block-wide layer normalization for rows no longer than the block (TPB >= ld), where each
// thread already holds its single input value.
//   val:         this thread's un-normalized element (only used when threadIdx.x < ld).
//   thread_data: this thread's contribution; key = x / ld, value = x * x / ld.
//   ld:          number of elements in the row.
//   idx:         index in `output` where this thread writes its result.
//   beta, gamma: per-element bias and scale, indexed by threadIdx.x.
// Must be called by every thread of the block, since it contains a block reduction and a barrier.
// NOTE(review): no epsilon is added before Rsqrt, so a row with zero variance yields a
// non-finite rsigma — confirm whether callers rely on this.
template <typename T, int TPB>
__device__ inline void LayerNormSmall(const T val, const cub::KeyValuePair<T, T>& thread_data, const int ld, const int idx,
const T* beta, const T* gamma, T* output) {
// Assuming thread_data is already divided by ld
// Small settings: the block covers the leading dimension TPB >= ld. The input
// value is available in a register
using BlockReduce = cub::BlockReduce<cub::KeyValuePair<T, T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
KeyValuePairSum pair_sum;
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum);
// Thread 0 turns the reduced sums into mean and 1/std.dev. (variance = E[x^2] - mean^2)
// and publishes them through shared memory.
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = Rsqrt(sum_kv.value - mu * mu);
}
// Make mu and rsigma visible to all threads before they are used.
__syncthreads();
// Threads beyond the row length have nothing to write.
if (threadIdx.x < ld) {
const T g(gamma[threadIdx.x]);
const T b(beta[threadIdx.x]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/common.h"
#include "core/providers/cuda/cudnn_common.h"
#include "core/framework/tensorprotoutils.h"
#include "onnx/defs/tensor_proto_util.h"
#include "skip_layer_norm.h"
#include "skip_layer_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Registers SkipLayerNorm<T> as the CUDA-provider kernel for the SkipLayerNormalization
// operator (kMSDomain, version 1), with type constraint "T" bound to tensors of element type T.
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
SkipLayerNormalization, \
kMSDomain, \
1, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
SkipLayerNorm<T>);
// Instantiate the registration for the two element types handled by this kernel.
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
using namespace ONNX_NAMESPACE;
// Constructor: the kernel keeps no state of its own, so the kernel info is simply
// forwarded to the CudaKernel base class.
template <typename T>
SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
    : CudaKernel(op_kernel_info) {}
// Validates the operator inputs, allocates the output and launches the CUDA implementation.
//
// Inputs (by index): 0 input (3-D), 1 skip (same shape as input), 2 gamma and 3 beta
// (1-D, length equal to the last dimension of input).
// Output 0 has the same shape as input.
//
// Returns INVALID_ARGUMENT when a shape check fails or the tensor has more elements than the
// int-based kernel interface can address, and FAIL when the kernel launch reports an error.
// All validation runs before the output is allocated, so invalid inputs allocate nothing.
template <typename T>
Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* input = ctx->Input<Tensor>(0);
  const Tensor* skip = ctx->Input<Tensor>(1);
  const Tensor* gamma = ctx->Input<Tensor>(2);
  const Tensor* beta = ctx->Input<Tensor>(3);

  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() != 3) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input is expected to have 3 dimensions, got ", input_dims.size());
  }

  if (input->Shape() != skip->Shape()) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "skip is expected to have same shape as input");
  }

  const auto& gamma_dims = gamma->Shape().GetDims();
  if (gamma_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "gamma is expected to have 1 dimension, got ", gamma_dims.size());
  }
  if (gamma_dims[0] != input_dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Last dimension of gamma and input does not match");
  }

  const auto& beta_dims = beta->Shape().GetDims();
  if (beta_dims.size() != 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "beta is expected to have 1 dimension, got ", beta_dims.size());
  }
  if (beta_dims[0] != input_dims[2]) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Last dimension of beta and input does not match");
  }

  // Count elements in 64 bits first: the kernel interface takes int, and multiplying the
  // narrowed dimensions directly could overflow.
  const auto total_elements = input_dims[0] * input_dims[1] * input_dims[2];
  constexpr long long kMaxInt32 = 2147483647LL;
  if (total_elements > kMaxInt32) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "input has too many elements for SkipLayerNormalization, got ", total_elements);
  }

  Tensor* output = ctx->Output(0, input->Shape());

  // Nothing to normalize for an empty tensor; this also keeps a zero hidden size away from
  // the launcher, which divides the element count by it.
  if (total_elements == 0) {
    return Status::OK();
  }

  const int batch_size = static_cast<int>(input_dims[0]);
  const int hidden_size = static_cast<int>(input_dims[2]);
  const int element_count = static_cast<int>(total_elements);
  const size_t element_size = sizeof(T);

  if (!LaunchSkipLayerNormKernel(
          output->template MutableData<T>(),
          input->template Data<T>(),
          skip->template Data<T>(),
          gamma->template Data<T>(),
          beta->template Data<T>(),
          batch_size,
          hidden_size,
          element_count,
          element_size)) {
    // Get last error to reset it to cudaSuccess.
    CUDA_CALL(cudaGetLastError());
    return Status(common::ONNXRUNTIME, common::FAIL);
  }

  return Status::OK();
}
} //namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
// CUDA kernel class registered for the SkipLayerNormalization contrib operator.
// T is the floating-point element type of the input, skip, gamma, beta and output tensors.
template <typename T>
class SkipLayerNorm final : public CudaKernel {
 public:
  // explicit: an OpKernelInfo must never convert implicitly into a kernel object.
  explicit SkipLayerNorm(const OpKernelInfo& op_kernel_info);

  // Validates inputs, allocates the output and runs the kernel; see the definition for details.
  Status ComputeInternal(OpKernelContext* context) const override;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,129 @@
/*
The implementation of this file is based on skipLayerNorm plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "layer_norm.cuh"
#include "skip_layer_norm_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Kernel: adds input and skip, then layer-normalizes the sum, one block per row of ld elements.
// Variant for rows that fit in a single block (TPB >= ld): each thread keeps its one element
// in a register and hands it to LayerNormSmall.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernelSmall(
    const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output) {
  // 1 / ld, applied per element so the reduced sums are already means.
  const T inv_ld = T(1) / T(ld);
  const int row_start = blockIdx.x * ld;
  const int idx = row_start + threadIdx.x;

  KeyValuePairSum pair_sum;

  // key accumulates x / ld, value accumulates x * x / ld.
  cub::KeyValuePair<T, T> partial(0, 0);
  T sum = 0;

  // Threads beyond the row length contribute zeros to the reduction.
  if (threadIdx.x < ld) {
    sum = input[idx] + skip[idx];
    const T scaled = inv_ld * sum;
    partial = pair_sum(partial, cub::KeyValuePair<T, T>(scaled, scaled * sum));
  }

  LayerNormSmall<T, TPB>(sum, partial, ld, idx, beta, gamma, output);
}
// Kernel: adds input and skip, then layer-normalizes the sum, one block per row of ld elements.
// General variant: each thread walks the row in steps of TPB, stores the un-normalized sums
// in `output`, and LayerNorm then normalizes them in place.
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
    const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output) {
  // 1 / ld, applied per element so the reduced sums are already means.
  const T inv_ld = T(1) / T(ld);
  const int row_start = blockIdx.x * ld;

  KeyValuePairSum pair_sum;

  // key accumulates x / ld, value accumulates x * x / ld.
  cub::KeyValuePair<T, T> partial(0, 0);

  for (int col = threadIdx.x; col < ld; col += TPB) {
    const int idx = row_start + col;
    const T sum = input[idx] + skip[idx];
    output[idx] = sum;

    const T scaled = inv_ld * sum;
    partial = pair_sum(partial, cub::KeyValuePair<T, T>(scaled, scaled * sum));
  }

  LayerNorm<T, TPB>(partial, ld, row_start, beta, gamma, output);
}
// Chooses a kernel variant and block size from the row length ld and launches one block per
// row on `stream`. n is the total number of elements, so n / ld rows are processed.
// Returns the outcome of CUDA_CALL on the launch status.
template <typename T>
bool ComputeSkipLayerNorm(
    cudaStream_t stream, const int ld, const int n, const T* input, const T* skip,
    const T* beta, const T* gamma, T* output) {
  // n covers the whole tensor, so it has to split evenly into rows of ld elements.
  assert(n % ld == 0);
  const int row_count = n / ld;

  // Rows of up to 128 elements, or of exactly 384, are handled by the one-element-per-thread kernel.
  const bool single_pass = (ld <= 128) || (ld == 384);

  if (!single_pass) {
    constexpr int kThreads = 256;
    SkipLayerNormKernel<T, kThreads><<<row_count, kThreads, 0, stream>>>(ld, input, skip, beta, gamma, output);
  } else if (ld <= 32) {
    constexpr int kThreads = 32;
    SkipLayerNormKernelSmall<T, kThreads>
        <<<row_count, kThreads, 0, stream>>>(ld, input, skip, beta, gamma, output);
  } else if (ld <= 128) {
    constexpr int kThreads = 128;
    SkipLayerNormKernelSmall<T, kThreads>
        <<<row_count, kThreads, 0, stream>>>(ld, input, skip, beta, gamma, output);
  } else {
    constexpr int kThreads = 384;
    SkipLayerNormKernelSmall<T, kThreads>
        <<<row_count, kThreads, 0, stream>>>(ld, input, skip, beta, gamma, output);
  }

  return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchSkipLayerNormKernel(
void* output,
const void* input,
const void* skip,
const void* gamma,
const void* beta,
const int batch_size,
const int hidden_size,
const int element_count,
const size_t element_size) {
// use default stream
const cudaStream_t stream = nullptr;
if (element_size == 2) {
return ComputeSkipLayerNorm(
stream, hidden_size, element_count,
reinterpret_cast<const half*>(input), reinterpret_cast<const half*>(skip),
reinterpret_cast<const half*>(beta), reinterpret_cast<const half*>(gamma),
reinterpret_cast<half*>(output));
} else {
return ComputeSkipLayerNorm(
stream, hidden_size, element_count,
reinterpret_cast<const float*>(input), reinterpret_cast<const float*>(skip),
reinterpret_cast<const float*>(beta), reinterpret_cast<const float*>(gamma),
reinterpret_cast<float*>(output));
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Host-side launcher for the SkipLayerNormalization CUDA kernels.
// The void* buffers hold half or float elements, selected by element_size
// (2 selects half; any other value is treated as float by the implementation).
// Returns false when the kernel launch reports a CUDA error, true otherwise.
bool LaunchSkipLayerNormKernel(
void* output, // output tensor
const void* input, // input tensor
const void* skip, // skip tensor
const void* gamma, // Layer normalization gamma tensor
const void* beta, // Layer normalization beta tensor
const int batch_size, // batch size (B)
const int hidden_size, // hidden size, it is the leading dimension (ld)
const int element_count, // number of elements in input tensor
const size_t element_size // element size of input tensor
);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -26,6 +26,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@ -35,6 +37,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
@ -61,6 +65,8 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
@ -70,6 +76,8 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,

View file

@ -209,6 +209,124 @@ void RegisterBertSchemas() {
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(EmbedLayerNormalization)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("Embedding Layer Normalization")
.Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1")
.Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1")
.Input(2, "mask", "2D attention mask with shape (batch_size, sequence_length)", "T1")
.Input(3, "word_embedding", "2D with shape (,hidden_size)", "T")
.Input(4, "position_embedding", "2D with shape (, hidden_size)", "T")
.Input(5, "segment_embedding", "2D with shape (, hidden_size)", "T")
.Input(6, "gamma", "1D gamma tensor for layer normalization with shape (hidden_size)", "T")
.Input(7, "beta", "1D beta tensor for layer normalization with shape (hidden_size)", "T")
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1")
.TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
  // Type inference: output 0 takes its element type from input 3 (word_embedding),
  // output 1 (mask_index) takes its element type from input 0.
  propagateElemTypeFromInputToOutput(ctx, 3, 0);
  propagateElemTypeFromInputToOutput(ctx, 0, 1);

  // Every input shape is dereferenced below, so stop (leaving the output shapes
  // un-inferred) as soon as any of them is unavailable instead of reading a shape
  // that is not there.
  if (!hasInputShape(ctx, 0) || !hasInputShape(ctx, 1) || !hasInputShape(ctx, 2) ||
      !hasInputShape(ctx, 3) || !hasInputShape(ctx, 4) || !hasInputShape(ctx, 5) ||
      !hasInputShape(ctx, 6) || !hasInputShape(ctx, 7)) {
    return;
  }

  auto& input_ids_shape = getInputShape(ctx, 0);
  auto& input_ids_dims = input_ids_shape.dim();
  auto& segment_ids_shape = getInputShape(ctx, 1);
  auto& segment_ids_dims = segment_ids_shape.dim();
  auto& mask_shape = getInputShape(ctx, 2);
  auto& mask_dims = mask_shape.dim();
  if (input_ids_dims.size() != 2 || segment_ids_dims.size() != 2 || mask_dims.size() != 2) {
    fail_shape_inference("Inputs 0, 1 and 2 shall be 2 dimensions");
  }

  // Dimension 1 of inputs 0, 1 and 2 must be a concrete value and must agree.
  if (input_ids_shape.dim(1).has_dim_value() && segment_ids_shape.dim(1).has_dim_value() && mask_shape.dim(1).has_dim_value()) {
    if (input_ids_shape.dim(1).dim_value() != segment_ids_shape.dim(1).dim_value() || input_ids_shape.dim(1).dim_value() != mask_shape.dim(1).dim_value()) {
      fail_shape_inference("Inputs 0, 1 and 2 shall have same value in dimension 1");
    }
  } else {
    fail_shape_inference("Inputs 0, 1 and 2 shall have value in dimension 1");
  }

  // get hidden_size from the last dimension of embedding
  auto& word_embedding_shape = getInputShape(ctx, 3);
  auto& word_embedding_dims = word_embedding_shape.dim();
  if (word_embedding_dims.size() != 2 || !word_embedding_dims[1].has_dim_value()) {
    fail_shape_inference("word_embedding should have 2 dimensions and dimension size is known.");
  }
  const int64_t hidden_size = word_embedding_shape.dim(1).dim_value();

  // For the remaining inputs the rank is always validated, but the size is only
  // compared against hidden_size when the dimension carries a concrete value:
  // dim_value() on a symbolic dimension yields 0, which would be misreported as
  // a mismatch.
  auto& position_embedding_shape = getInputShape(ctx, 4);
  if (position_embedding_shape.dim().size() != 2) {
    fail_shape_inference("position_embedding should have 2 dimensions");
  }
  const auto& position_hidden_dim = position_embedding_shape.dim(1);
  if (position_hidden_dim.has_dim_value() && position_hidden_dim.dim_value() != hidden_size) {
    fail_shape_inference("The last dimension of word_embedding and position_embedding does not match.");
  }

  auto& segment_embedding_shape = getInputShape(ctx, 5);
  if (segment_embedding_shape.dim().size() != 2) {
    fail_shape_inference("segment_embedding should have 2 dimensions");
  }
  const auto& segment_hidden_dim = segment_embedding_shape.dim(1);
  if (segment_hidden_dim.has_dim_value() && segment_hidden_dim.dim_value() != hidden_size) {
    fail_shape_inference("The last dimension of word_embedding and segment_embedding does not match.");
  }

  auto& gamma_shape = getInputShape(ctx, 6);
  if (gamma_shape.dim().size() != 1) {
    fail_shape_inference("gamma should have 1 dimension");
  }
  const auto& gamma_dim = gamma_shape.dim(0);
  if (gamma_dim.has_dim_value() && gamma_dim.dim_value() != hidden_size) {
    fail_shape_inference("The last dimension of word_embedding and gamma does not match.");
  }

  auto& beta_shape = getInputShape(ctx, 7);
  if (beta_shape.dim().size() != 1) {
    fail_shape_inference("beta should have 1 dimension");
  }
  const auto& beta_dim = beta_shape.dim(0);
  if (beta_dim.has_dim_value() && beta_dim.dim_value() != hidden_size) {
    fail_shape_inference("The last dimension of word_embedding and beta does not match.");
  }

  // mask shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
  ONNX_NAMESPACE::TensorShapeProto output_shape;
  for (auto& dim : mask_dims) {
    *output_shape.add_dim() = dim;
  }
  // Always emit the third dimension so the inferred rank matches the declared 3D
  // output; its value is only filled in when hidden_size is positive.
  auto* output_hidden_dim = output_shape.add_dim();
  if (hidden_size > 0) {
    output_hidden_dim->set_dim_value(hidden_size);
  }
  updateOutputShape(ctx, 0, output_shape);

  // mask_index shape is (batch_size)
  ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
  *mask_index_shape.add_dim() = mask_shape.dim(0);
  updateOutputShape(ctx, 1, mask_index_shape);
});
// Schema for the fused "skip + layer normalization" contrib op: two 3D tensors
// (input, skip) plus 1D gamma/beta, producing one 3D output. Output type and
// shape are propagated from the first input.
ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
    .SetDoc("Skip and Layer Normalization Fusion")
    .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
    .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T")
    .Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T")
    .Input(3, "beta", "1D beta tensor for layer normalization with shape (hidden_size)", "T")
    .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
    .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
    .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
}
void RegisterContribSchemas() {
@ -1792,8 +1910,6 @@ Example 4:
}
});
RegisterBertSchemas();
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
RegisterNchwcSchemas();
@ -1812,6 +1928,8 @@ Example 4:
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
RegisterBertSchemas();
#ifdef MICROSOFT_INTERNAL
// register internal ops
RegisterInternalSchemas();

View file

@ -0,0 +1,273 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
static void RunTest(
const std::vector<int32_t>& input_ids_data,
const std::vector<int32_t>& segment_ids_data,
const std::vector<int32_t>& mask_data,
const std::vector<float>& word_embedding_data,
const std::vector<float>& position_embedding_data,
const std::vector<float>& segment_embedding_data,
const std::vector<float>& gamma_data,
const std::vector<float>& beta_data,
const std::vector<float>& output_data,
const std::vector<int32_t>& mask_index_data,
int batch_size,
int sequence_length,
int hidden_size,
bool use_float16 = false) {
int min_cuda_architecture = use_float16 ? 530 : 0;
if (HasCudaEnvironment(min_cuda_architecture)) {
// Input and output shapes
// Input 0 - input_ids : (batch_size, sequence_length)
// Input 1 - segment_ids : (batch_size, sequence_length)
// Input 2 - mask : (batch_size, sequence_length)
// Input 3 - word_embedding : (,hidden_size)
// Input 4 - position_embedding : (,hidden_size)
// Input 5 - segment_embedding : (,hidden_size)
// Input 6 - gamma : (hidden_size)
// Input 7 - beta : (hidden_size)
// Output 0 - output : (batch_size, sequence_length, hidden_size)
// Output 1 - mask_index : (batch_size)
std::vector<int64_t> input_ids_dims = {batch_size, sequence_length};
std::vector<int64_t> segment_ids_dims = {batch_size, sequence_length};
std::vector<int64_t> mask_dims = {batch_size, sequence_length};
ASSERT_TRUE(word_embedding_data.size() % hidden_size == 0);
std::vector<int64_t> word_embedding_dims = {static_cast<int64_t>(word_embedding_data.size() / hidden_size), hidden_size};
ASSERT_TRUE(position_embedding_data.size() % hidden_size == 0);
std::vector<int64_t> position_embedding_dims = {static_cast<int64_t>(position_embedding_data.size() / hidden_size), hidden_size};
ASSERT_TRUE(segment_embedding_data.size() % hidden_size == 0);
std::vector<int64_t> segment_embedding_dims = {static_cast<int64_t>(segment_embedding_data.size() / hidden_size), hidden_size};
std::vector<int64_t> gamma_dims = {hidden_size};
std::vector<int64_t> beta_dims = gamma_dims;
std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> mask_index_dims = {batch_size};
OpTester tester("EmbedLayerNormalization", 1, onnxruntime::kMSDomain);
tester.AddInput<int32_t>("input_ids", input_ids_dims, input_ids_data);
tester.AddInput<int32_t>("segment_ids", segment_ids_dims, segment_ids_data);
tester.AddInput<int32_t>("mask", mask_dims, mask_data);
if (use_float16) {
tester.AddInput<MLFloat16>("word_embedding", word_embedding_dims, ToFloat16(word_embedding_data));
tester.AddInput<MLFloat16>("position_embedding", position_embedding_dims, ToFloat16(position_embedding_data));
tester.AddInput<MLFloat16>("segment_embedding", segment_embedding_dims, ToFloat16(segment_embedding_data));
tester.AddInput<MLFloat16>("gamma", gamma_dims, ToFloat16(gamma_data));
tester.AddInput<MLFloat16>("beta", beta_dims, ToFloat16(beta_data));
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("word_embedding", word_embedding_dims, word_embedding_data);
tester.AddInput<float>("position_embedding", position_embedding_dims, position_embedding_data);
tester.AddInput<float>("segment_embedding", segment_embedding_dims, segment_embedding_data);
tester.AddInput<float>("gamma", gamma_dims, gamma_data);
tester.AddInput<float>("beta", beta_dims, beta_data);
tester.AddOutput<float>("output", output_dims, output_data);
}
tester.AddOutput<int32_t>("mask_index", mask_index_dims, mask_index_data);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
// Single batch, two tokens, hidden size 4, float32 tensors.
TEST(EmbedLayerNormTest, EmbedLayerNormBatch1) {
  const int batch_size = 1;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // Integer inputs, each (batch_size, sequence_length).
  const std::vector<int32_t> input_ids = {1, 3};
  const std::vector<int32_t> segment_ids = {0, 1};
  const std::vector<int32_t> mask = {1, 1};

  // Embedding tables, one row of hidden_size values per line.
  const std::vector<float> word_embedding = {
      0.2f, 0.1f, 0.4f, -0.6f,
      0.3f, 0.2f, 0.5f, 0.6f,
      0.6f, 0.7f, 0.0f, -0.1f,
      0.8f, 0.6f, 0.9f, 1.2f,
      0.1f, 0.3f, 0.5f, 0.9f,
      1.0f, -2.0f, 1.1f, 0.8f};
  const std::vector<float> position_embedding = {
      0.1f, 0.1f, 0.4f, 0.6f,
      0.6f, 0.0f, 0.8f, 0.6f,
      0.3f, 0.9f, -2.0f, 0.8f};
  const std::vector<float> segment_embedding = {
      0.3f, 0.4f, 0.9f, 0.1f,
      0.7f, 0.3f, 0.5f, 0.2f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.25f, 0.15f, 0.45f, -0.66f};
  const std::vector<float> beta = {0.6f, 0.2f, 0.5f, -0.6f};

  // Expected results: one row of hidden_size values per token, then one mask index per batch.
  const std::vector<float> expected_output = {
      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337};
  const std::vector<int32_t> expected_mask_index = {2};

  RunTest(input_ids, segment_ids, mask,
          word_embedding, position_embedding, segment_embedding,
          gamma, beta,
          expected_output, expected_mask_index,
          batch_size, sequence_length, hidden_size);
}
// Same inputs as the float32 single-batch case, run with use_float16 = true and
// checked against its own list of expected output values.
TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_Float16) {
  const int batch_size = 1;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // Integer inputs, each (batch_size, sequence_length).
  const std::vector<int32_t> input_ids = {1, 3};
  const std::vector<int32_t> segment_ids = {0, 1};
  const std::vector<int32_t> mask = {1, 1};

  // Embedding tables, one row of hidden_size values per line.
  const std::vector<float> word_embedding = {
      0.2f, 0.1f, 0.4f, -0.6f,
      0.3f, 0.2f, 0.5f, 0.6f,
      0.6f, 0.7f, 0.0f, -0.1f,
      0.8f, 0.6f, 0.9f, 1.2f,
      0.1f, 0.3f, 0.5f, 0.9f,
      1.0f, -2.0f, 1.1f, 0.8f};
  const std::vector<float> position_embedding = {
      0.1f, 0.1f, 0.4f, 0.6f,
      0.6f, 0.0f, 0.8f, 0.6f,
      0.3f, 0.9f, -2.0f, 0.8f};
  const std::vector<float> segment_embedding = {
      0.3f, 0.4f, 0.9f, 0.1f,
      0.7f, 0.3f, 0.5f, 0.2f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.25f, 0.15f, 0.45f, -0.66f};
  const std::vector<float> beta = {0.6f, 0.2f, 0.5f, -0.6f};

  // Expected results for the half-precision run.
  const std::vector<float> expected_output = {
      0.369873046875, 0.061676025390625, 1.1591796875, -0.8515625,
      0.7431640625, -0.057586669921875, 0.84326171875, -0.8525390625};
  const std::vector<int32_t> expected_mask_index = {2};

  const bool use_float16 = true;
  RunTest(input_ids, segment_ids, mask,
          word_embedding, position_embedding, segment_embedding,
          gamma, beta,
          expected_output, expected_mask_index,
          batch_size, sequence_length, hidden_size,
          use_float16);
}
// Multi-batch float32 case. Note that, despite the "Batch2" name, the data below
// describes three batches (batch_size = 3); the last one has a mask row of {1, 0}.
TEST(EmbedLayerNormTest, EmbedLayerNormBatch2) {
  const int batch_size = 3;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // Integer inputs, each (batch_size, sequence_length): one batch per line.
  const std::vector<int32_t> input_ids = {
      1, 3,
      1, 3,
      2, 0};
  const std::vector<int32_t> segment_ids = {
      0, 1,
      0, 1,
      0, 0};
  const std::vector<int32_t> mask = {
      1, 1,
      1, 1,
      1, 0};

  // Embedding tables, one row of hidden_size values per line.
  const std::vector<float> word_embedding = {
      0.2f, 0.1f, 0.4f, -0.6f,
      0.3f, 0.2f, 0.5f, 0.6f,
      0.6f, 0.7f, 0.0f, -0.1f,
      0.8f, 0.6f, 0.9f, 1.2f,
      0.1f, 0.3f, 0.5f, 0.9f,
      1.0f, -2.0f, 1.1f, 0.8f};
  const std::vector<float> position_embedding = {
      0.1f, 0.1f, 0.4f, 0.6f,
      0.6f, 0.0f, 0.8f, 0.6f,
      0.3f, 0.9f, -2.0f, 0.8f};
  const std::vector<float> segment_embedding = {
      0.3f, 0.4f, 0.9f, 0.1f,
      0.7f, 0.3f, 0.5f, 0.2f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.25f, 0.15f, 0.45f, -0.66f};
  const std::vector<float> beta = {0.6f, 0.2f, 0.5f, -0.6f};

  // Expected results: one row of hidden_size values per token, two tokens per batch.
  const std::vector<float> expected_output = {
      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337,
      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337,
      0.57668739557266235, 0.2979130744934082, 0.96158987283706665, 0.44627034664154053,
      0.64977931976318359, 0.11039737612009048, 1.1869535446166992, 0.14469735324382782};
  // One entry per batch; for this data each equals the count of 1s in that batch's mask row.
  const std::vector<int32_t> expected_mask_index = {2, 2, 1};

  RunTest(input_ids, segment_ids, mask,
          word_embedding, position_embedding, segment_embedding,
          gamma, beta,
          expected_output, expected_mask_index,
          batch_size, sequence_length, hidden_size);
}
} // namespace test
} // namespace onnxruntime

View file

@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
static void RunTest(
const std::vector<float>& input_data,
const std::vector<float>& skip_data,
const std::vector<float>& gamma_data,
const std::vector<float>& beta_data,
const std::vector<float>& output_data,
int batch_size,
int sequence_length,
int hidden_size,
bool use_float16 = false) {
int min_cuda_architecture = use_float16 ? 530 : 0;
if (HasCudaEnvironment(min_cuda_architecture)) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
// Input and output shapes
// Input 0 - input: (batch_size, sequence_length, hidden_size)
// Input 1 - skip : (batch_size, sequence_length, hidden_size)
// Input 2 - gamma: (hidden_size)
// Input 3 - beta : (hidden_size)
// Output : (batch_size, sequence_length, hidden_size)
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> skip_dims = input_dims;
std::vector<int64_t> gamma_dims = {hidden_size};
std::vector<int64_t> beta_dims = gamma_dims;
std::vector<int64_t> output_dims = input_dims;
if (use_float16) {
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
test.AddInput<MLFloat16>("skip", skip_dims, ToFloat16(skip_data));
test.AddInput<MLFloat16>("gamma", gamma_dims, ToFloat16(gamma_data));
test.AddInput<MLFloat16>("beta", beta_dims, ToFloat16(beta_data));
test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
} else {
test.AddInput<float>("input", input_dims, input_data);
test.AddInput<float>("skip", skip_dims, skip_data);
test.AddInput<float>("gamma", gamma_dims, gamma_data);
test.AddInput<float>("beta", beta_dims, beta_data);
test.AddOutput<float>("output", output_dims, output_data);
}
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
// Single batch, two tokens, hidden size 4, float32 tensors.
TEST(SkipLayerNormTest, SkipLayerNormBatch1) {
  const int batch_size = 1;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // input and skip: one row of hidden_size values per token.
  const std::vector<float> input = {
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};
  const std::vector<float> skip = {
      0.1f, -0.2f, 0.3f, 1.0f,
      0.5f, 0.1f, 0.4f, 1.6f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.3f, 0.2f, 4.0f, 2.2f};
  const std::vector<float> beta = {0.2f, 0.1f, 0.4f, 1.6f};

  const std::vector<float> expected_output = {
      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};

  RunTest(input, skip, gamma, beta, expected_output,
          batch_size, sequence_length, hidden_size);
}
// Same data as the float32 single-batch case, run with use_float16 = true.
// The expected values are the same literals as in the float32 test; RunTest
// converts them with ToFloat16 before comparison.
TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16) {
  const int batch_size = 1;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // input and skip: one row of hidden_size values per token.
  const std::vector<float> input = {
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};
  const std::vector<float> skip = {
      0.1f, -0.2f, 0.3f, 1.0f,
      0.5f, 0.1f, 0.4f, 1.6f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.3f, 0.2f, 4.0f, 2.2f};
  const std::vector<float> beta = {0.2f, 0.1f, 0.4f, 1.6f};

  const std::vector<float> expected_output = {
      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};

  const bool use_float16 = true;
  RunTest(input, skip, gamma, beta, expected_output,
          batch_size, sequence_length, hidden_size, use_float16);
}
// Two batches of two tokens each, hidden size 4, float32 tensors. Both batches
// use the same input rows but different skip rows.
TEST(SkipLayerNormTest, SkipLayerNormBatch2) {
  const int batch_size = 2;
  const int sequence_length = 2;
  const int hidden_size = 4;

  // input and skip: one row of hidden_size values per token, two tokens per batch.
  const std::vector<float> input = {
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};
  const std::vector<float> skip = {
      0.1f, -0.2f, 0.3f, 1.0f,
      0.5f, 0.1f, 0.4f, 1.6f,
      1.8f, -0.3f, 0.0f, 1.f,
      -0.5f, 0.4f, 0.8f, -0.6f};

  // Layer-normalization parameters, (hidden_size).
  const std::vector<float> gamma = {0.3f, 0.2f, 4.0f, 2.2f};
  const std::vector<float> beta = {0.2f, 0.1f, 0.4f, 1.6f};

  const std::vector<float> expected_output = {
      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
      0.55470430850982666, -0.15080101788043976, -2.3229825496673584, 3.255286693572998,
      0.15631480515003204, 0.21066918969154358, 4.9432611465454102, -1.7957965135574341};

  RunTest(input, skip, gamma, beta, expected_output,
          batch_size, sequence_length, hidden_size);
}
} // namespace test
} // namespace onnxruntime