From b2c193752376e6c3a802f8165093d8602a4ce4ba Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 7 Oct 2019 17:29:43 -0700 Subject: [PATCH] Add EmbedLayerNormalization and SkipLayerNormalization ops for bert optimization (#2012) * Add Embed Layer Normalization and Skip Layer Normalization ops for bert optimization. * add float16 test for skiplayernorm * Add test for EmbedLayerNormalization op * fix cpu build error * fix build warning * update HasCudaEnvironment function * handle cuda error --- .../contrib_ops/cuda/bert/embed_layer_norm.cc | 123 ++++++++ .../contrib_ops/cuda/bert/embed_layer_norm.h | 24 ++ .../cuda/bert/embed_layer_norm_impl.cu | 206 +++++++++++++ .../cuda/bert/embed_layer_norm_impl.h | 26 ++ .../contrib_ops/cuda/bert/layer_norm.cuh | 138 +++++++++ .../contrib_ops/cuda/bert/skip_layer_norm.cc | 100 +++++++ .../contrib_ops/cuda/bert/skip_layer_norm.h | 24 ++ .../cuda/bert/skip_layer_norm_impl.cu | 129 +++++++++ .../cuda/bert/skip_layer_norm_impl.h | 24 ++ .../contrib_ops/cuda_contrib_kernels.cc | 8 + .../core/graph/contrib_ops/contrib_defs.cc | 122 +++++++- .../contrib_ops/embedlayernorm_op_test.cc | 273 ++++++++++++++++++ .../test/contrib_ops/skiplayernorm_op_test.cc | 146 ++++++++++ 13 files changed, 1341 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh create mode 100644 onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc create mode 100644 onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h create mode 100644 
onnxruntime/test/contrib_ops/embedlayernorm_op_test.cc create mode 100644 onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc new file mode 100644 index 0000000000..bf88fa9ca0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/framework/tensorprotoutils.h" +#include "onnx/defs/tensor_proto_util.h" +#include "embed_layer_norm.h" +#include "embed_layer_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + EmbedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + EmbedLayerNorm); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +EmbedLayerNorm::EmbedLayerNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +} + +template +Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const { + const Tensor* input_ids = context->Input(0); + const Tensor* segment_ids = context->Input(1); + const Tensor* mask = context->Input(2); + const Tensor* word_embedding = context->Input(3); + const Tensor* position_embedding = context->Input(4); + const Tensor* segment_embedding = context->Input(5); + const Tensor* gamma = context->Input(6); + const Tensor* beta = context->Input(7); + + if (input_ids->Shape() != segment_ids->Shape() || input_ids->Shape() != mask->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0, 1 and 2 shall have same shape"); + } + + const auto input_dims = input_ids->Shape().GetDims(); + 
if (input_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input_ids is expected to have 2 dimensions, got ", input_dims.size()); + } + + const auto word_embedding_dims = word_embedding->Shape().GetDims(); + if (word_embedding_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "word_embedding is expected to have 2 dimensions, got ", word_embedding_dims.size()); + } + + const auto position_embedding_dims = position_embedding->Shape().GetDims(); + if (position_embedding_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "position_embedding is expected to have 2 dimensions, got ", position_embedding_dims.size()); + } + + const auto segment_embedding_dims = segment_embedding->Shape().GetDims(); + if (segment_embedding_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "segment_embedding is expected to have 2 dimensions, got ", segment_embedding_dims.size()); + } + + if (word_embedding_dims[1] != position_embedding_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "word_embedding and position_embedding shall have same dimension 1"); + } + int64_t hidden_size = word_embedding_dims[1]; + + std::vector out_dims; + out_dims.reserve(3); + out_dims.push_back(input_dims[0]); + out_dims.push_back(input_dims[1]); + out_dims.push_back(hidden_size); + TensorShape output_shape(out_dims); + Tensor* output = context->Output(0, output_shape); + + std::vector mask_index_dims; + mask_index_dims.push_back(input_dims[0]); + TensorShape mask_index_shape(mask_index_dims); + Tensor* mask_index = context->Output(1, mask_index_shape); + + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + size_t element_size = sizeof(T); + + if (!LaunchEmbedLayerNormKernel( + output->template MutableData(), + mask_index->template MutableData(), + input_ids->template Data(), + segment_ids->template Data(), + mask->template Data(), + 
gamma->template Data(), + beta->template Data(), + word_embedding->template Data(), + position_embedding->template Data(), + segment_embedding->template Data(), + static_cast(hidden_size), + batch_size, + sequence_length, + element_size)) { + // Get last error to reset it to cudaSuccess. + CUDA_CALL(cudaGetLastError()); + return Status(common::ONNXRUNTIME, common::FAIL); + } + + return Status::OK(); +} + +} //namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.h b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.h new file mode 100644 index 0000000000..c9ff25c0cd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class EmbedLayerNorm final : public CudaKernel { + public: + EmbedLayerNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* ctx) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu new file mode 100644 index 0000000000..3dd4e0488b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu @@ -0,0 +1,206 @@ +/* + The implementation of this file is based on embLayerNorm plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "layer_norm.cuh" +#include "embed_layer_norm_impl.h" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void MaskIndexKernelSmall(int sequence_length, const int* mask, int* mask_index) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // blockIdx.x is b + const int offset = blockIdx.x * sequence_length; // batch strides of sequence_length + + cub::Min min; + int thread_data(sequence_length); + + const int idx = offset + threadIdx.x; + if (threadIdx.x < sequence_length) { + const int val = mask[idx]; + if (val == 0) // masked position: report thread idx + { + thread_data = threadIdx.x; + } + } + + const auto min_index = BlockReduce(temp_storage).Reduce(thread_data, min); + + if (threadIdx.x == 0) { + mask_index[blockIdx.x] = min_index; + } +} + +template +__global__ void MaskIndexKernel(int sequence_length, const int* mask, int* mask_index) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // blockIdx.x is b + const int offset = blockIdx.x * sequence_length; // batch strides of sequence_length + + cub::Min min; + int thread_data(sequence_length); + + for (int i = threadIdx.x; i < sequence_length; i += TPB) { + const int idx = offset + i; + const int val = mask[idx]; + if (val == 0) // masked position: report thread idx + { + thread_data = 
min(thread_data, i); + } + } + + const auto min_index = BlockReduce(temp_storage).Reduce(thread_data, min); + + if (threadIdx.x == 0) { + mask_index[blockIdx.x] = min_index; + } +} + +inline bool ComputeMaskIndex(cudaStream_t stream, const int sequence_length, const int batch_size, const int* mask, int* mask_index) { + // Mask idx is of length batch_size and assumes the valid region is contiguous starting + // from the beginning of the sequence + + // Assume n = batch_size x sequence_length + if (sequence_length <= 32) { + MaskIndexKernelSmall<32><<>>(sequence_length, mask, mask_index); + } else if (sequence_length <= 128) { + MaskIndexKernelSmall<128><<>>(sequence_length, mask, mask_index); + } else if (sequence_length == 384) { + MaskIndexKernelSmall<384><<>>(sequence_length, mask, mask_index); + } else { + MaskIndexKernel<256><<>>(sequence_length, mask, mask_index); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +template +__global__ void EmbedLayerNormKernel( + int hidden_size, const int* input_ids, const int* segment_ids, const T* beta, const T* gamma, + const T* word_embedding, const T* position_embedding, const T* segment_embedding, + T* output) { + KeyValuePairSum pair_sum; + // 1. lookup word and segment of the block + // blockIdx.x = position in the sequence + // blockIdx.y = batch + // gridDim.x = sequence_length + // gridDim.y = batch_size + __shared__ int word_id; + __shared__ int segment_id; + + const T rld = T(1.f) / T(hidden_size); + const int sequence_position = blockIdx.y * gridDim.x + blockIdx.x; + if (threadIdx.x == 0) { + word_id = input_ids[sequence_position]; + segment_id = segment_ids[sequence_position]; + } + __syncthreads(); + + // 2. 
load pos/segment/word embeddings and add them together + // offset into embeddings is given by word_id * hidden_size + const int position_offset = blockIdx.x * hidden_size; + const int word_offset = word_id * hidden_size; + const int segment_offset = segment_id * hidden_size; + // the output offset is given by b * (sequence_length * hidden_size) + s * hidden_size + const int output_offset = sequence_position * hidden_size; + + cub::KeyValuePair thread_data(0, 0); + + for (int it = threadIdx.x; it < hidden_size; it += TPB) { + const T w(word_embedding[word_offset + it]); + const T t(segment_embedding[segment_offset + it]); + const T p(position_embedding[position_offset + it]); + const T val = w + t + p; + + output[output_offset + it] = val; + const T rldval = rld * val; + thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); + } + + // 3. layer norm on the sum + LayerNorm(thread_data, hidden_size, output_offset, beta, gamma, output); +} + +template +bool EmbedSkipLayerNorm( + cudaStream_t stream, int hidden_size, int batch_size, int sequence_length, + const int* input_ids, const int* segment_ids, const T* beta, const T* gamma, + const T* word_embedding, const T* position_embedding, const T* segment_embedding, + T* output) { + constexpr int tpb = 256; + const dim3 grid(sequence_length, batch_size, 1); + const dim3 block(tpb, 1, 1); + + EmbedLayerNormKernel + <<>>(hidden_size, input_ids, segment_ids, beta, gamma, word_embedding, position_embedding, segment_embedding, output); + + return CUDA_CALL(cudaPeekAtLastError()); +} + +bool LaunchEmbedLayerNormKernel( + void* output, + void* mask_index, + const int* input_ids, + const int* segment_ids, + const int* input_mask, + const void* gamma, + const void* beta, + const void* word_embedding, + const void* position_embedding, + const void* segment_embedding, + const int hidden_size, + int batch_size, + int sequence_length, + const size_t element_size) { + const cudaStream_t stream = nullptr; // 
default stream + + if (!ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast(mask_index))) { // third argument is ComputeMaskIndex's batch_size: mask_index holds one entry per batch row, so hidden_size must not be passed here + return false; + } + + if (element_size == 2) { + return EmbedSkipLayerNorm( + stream, hidden_size, batch_size, sequence_length, input_ids, segment_ids, + reinterpret_cast(beta), reinterpret_cast(gamma), + reinterpret_cast(word_embedding), reinterpret_cast(position_embedding), reinterpret_cast(segment_embedding), + reinterpret_cast(output)); + } else { + return EmbedSkipLayerNorm( + stream, hidden_size, batch_size, sequence_length, input_ids, segment_ids, + reinterpret_cast(beta), reinterpret_cast(gamma), + reinterpret_cast(word_embedding), reinterpret_cast(position_embedding), reinterpret_cast(segment_embedding), + reinterpret_cast(output)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h new file mode 100644 index 0000000000..e21131bb1a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +bool LaunchEmbedLayerNormKernel(void* output, // output tensor + void* mask_index, // output mask index + const int* input_ids, // input word IDs + const int* segment_ids, // input segment IDs + const int* input_mask, // input mask + const void* gamma, // weight for layer normalization + const void* beta, // bias for layer normalization + const void* word_embedding, // weights for word embeddings + const void* position_embedding, // weights for position embeddings + const void* segment_embedding, // weights for segment (like sentence) embeddings + const int hidden_size, // hidden size (that is head_size * num_heads) + int batch_size, // batch size + int sequence_length, // sequence length + const size_t element_size); // size of element in output tensor. 2 for half, 4 for float. + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh new file mode 100644 index 0000000000..8a639633f0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh @@ -0,0 +1,138 @@ +/* + The implementation of this file is based on bert plugins in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. +#pragma once + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include +#include +#include + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline T Rsqrt(const T& x); + +template <> +__device__ inline float Rsqrt(const float& x) { + return rsqrtf(x); +} + +template <> +__device__ inline half Rsqrt(const half& x) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return hrsqrt(x); +#else + return half(rsqrtf(float(x))); +#endif +} + +__device__ inline half2 AddHalf2(const half2 a, const half2 b) { +#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) + return __hadd2(a, b); +#else + return __halves2half2(__hadd(a.x, b.x), __hadd(a.y, b.y)); +#endif +} + +struct KeyValuePairSum { + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + return cub::KeyValuePair(a.key + b.key, a.value + b.value); + } + + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = AddHalf2(a2, b2); + return cub::KeyValuePair(res.x, res.y); + } + + __device__ inline cub::KeyValuePair operator()(const cub::KeyValuePair& a, const cub::KeyValuePair& b) { + return cub::KeyValuePair(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value)); + } +}; + +template +__device__ inline void LayerNorm( + const cub::KeyValuePair& thread_data, const int ld, const int offset, const T* beta, const T* gamma, T* output) { + // Assuming thread_data is already divided by ld + + using BlockReduce = cub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. 
+ + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu); + } + __syncthreads(); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = output[idx]; + const T g(gamma[i]); + const T b(beta[i]); + output[idx] = g * (val - mu) * rsigma + b; + } +} + +template +__device__ inline void LayerNormSmall(const T val, const cub::KeyValuePair& thread_data, const int ld, const int idx, + const T* beta, const T* gamma, T* output) { + // Assuming thread_data is already divided by ld + // Small settings: the block covers the leading dimension TPB >= ld. The input + // value is available in a register + + using BlockReduce = cub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T mu; // mean + __shared__ T rsigma; // 1 / std.dev. + + KeyValuePairSum pair_sum; + const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, pair_sum); + + if (threadIdx.x == 0) { + mu = sum_kv.key; + rsigma = Rsqrt(sum_kv.value - mu * mu); + } + __syncthreads(); + + if (threadIdx.x < ld) { + const T g(gamma[threadIdx.x]); + const T b(beta[threadIdx.x]); + output[idx] = g * (val - mu) * rsigma + b; + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc new file mode 100644 index 0000000000..20891cacad --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -0,0 +1,100 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/common.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/framework/tensorprotoutils.h" +#include "onnx/defs/tensor_proto_util.h" +#include "skip_layer_norm.h" +#include "skip_layer_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +using namespace ONNX_NAMESPACE; + +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +} + +template +Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input = ctx->Input(0); + const Tensor* skip = ctx->Input(1); + const Tensor* gamma = ctx->Input(2); + const Tensor* beta = ctx->Input(3); + Tensor* output = ctx->Output(0, input->Shape()); + + const auto input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input is expected to have 3 dimensions, got ", input_dims.size()); + } + + if (input->Shape() != skip->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input"); + } + + const auto gamma_dims = gamma->Shape().GetDims(); + if (gamma_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "gamma is expected to have 1 dimension, got ", gamma_dims.size()); + } + if (gamma_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of gamma and input does not match"); + } + + const auto beta_dims = beta->Shape().GetDims(); + if (beta_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "beta is expected to have 1 dimension, got ", 
beta_dims.size()); + } + if (beta_dims[0] != input_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Last dimension of beta and input does not match"); + } + + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int element_count = batch_size * sequence_length * hidden_size; + size_t element_size = sizeof(T); + + if (!LaunchSkipLayerNormKernel( + output->template MutableData(), + input->template Data(), + skip->template Data(), + gamma->template Data(), + beta->template Data(), + batch_size, + hidden_size, + element_count, + element_size)) { + // Get last error to reset it to cudaSuccess. + CUDA_CALL(cudaGetLastError()); + return Status(common::ONNXRUNTIME, common::FAIL); + } + + return Status::OK(); +} + +} //namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.h new file mode 100644 index 0000000000..c8b754a947 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class SkipLayerNorm final : public CudaKernel { + public: + SkipLayerNorm(const OpKernelInfo& op_kernel_info); + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu new file mode 100644 index 0000000000..7541c19f0f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu @@ -0,0 +1,129 @@ +/* + The implementation of this file is based on skipLayerNorm plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "layer_norm.cuh" +#include "skip_layer_norm_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void SkipLayerNormKernelSmall( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output) { + const T reverse_ld = T(1) / T(ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + cub::KeyValuePair thread_data(0, 0); + const int idx = offset + threadIdx.x; + T val = 0; // stays 0 for threads with threadIdx.x >= ld + + if (threadIdx.x < ld) { + val = input[idx] + skip[idx]; + const T rldval = reverse_ld * val; // pre-scaled by 1/ld, so the block-wide sums are already averages + thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); + } + + LayerNormSmall(val, thread_data, ld, idx, beta, gamma, output); // called by every thread, including those with threadIdx.x >= ld +} + +template +__global__ void SkipLayerNormKernel( + const int ld, const T* input, const T* skip, const T* beta, const T* gamma, T* output) { + const T reverse_ld = T(1) / T(ld); + const int offset = blockIdx.x * ld; + + KeyValuePairSum pair_sum; + // reduce x and x^2 + cub::KeyValuePair thread_data(0, 0); + + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const T val = input[idx] + skip[idx]; + const T rldval = reverse_ld * val; + thread_data = pair_sum(thread_data, cub::KeyValuePair(rldval, rldval * val)); + output[idx] = val; // the un-normalized sum is staged in output; LayerNorm reads it back from there + } + + LayerNorm(thread_data, ld, offset, beta, gamma, output); +} + +template +bool ComputeSkipLayerNorm( + cudaStream_t stream, const int ld, const int n, const T* input, const T* skip, + const T* beta, const T* gamma, T* output) { + // this must be true because n is the total size of the tensor + assert(n % ld == 0); + const int grid_size = n / ld; // number of ld-sized rows in the n-element tensor + + if (ld <= 32) { + constexpr int block_size = 32; + SkipLayerNormKernelSmall + <<>>(ld, input, skip, beta, gamma, output); + } else if (ld <= 128) { + constexpr int block_size = 128; + SkipLayerNormKernelSmall + <<>>(ld, input, skip, beta, gamma, output); + } else if (ld == 384) { // exact match only; every other ld above 128 takes the final else branch + constexpr int block_size = 384; + 
SkipLayerNormKernelSmall + <<>>(ld, input, skip, beta, gamma, output); + } else { + constexpr int block_size = 256; + SkipLayerNormKernel<<>>(ld, input, skip, beta, gamma, output); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +bool LaunchSkipLayerNormKernel( + void* output, + const void* input, + const void* skip, + const void* gamma, + const void* beta, + const int batch_size, + const int hidden_size, + const int element_count, + const size_t element_size) { + // use default stream + const cudaStream_t stream = nullptr; // NOTE(review): batch_size is not referenced anywhere in this function body + + if (element_size == 2) { // 2-byte elements take this branch; every other element_size takes the else branch + return ComputeSkipLayerNorm( + stream, hidden_size, element_count, + reinterpret_cast(input), reinterpret_cast(skip), + reinterpret_cast(beta), reinterpret_cast(gamma), + reinterpret_cast(output)); + } else { + return ComputeSkipLayerNorm( + stream, hidden_size, element_count, + reinterpret_cast(input), reinterpret_cast(skip), + reinterpret_cast(beta), reinterpret_cast(gamma), + reinterpret_cast(output)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h new file mode 100644 index 0000000000..b0d46ce20f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+// NOTE(review): tensors are passed type-erased (void*); element_size presumably selects float vs. half - confirm in skip_layer_norm_impl.cu.
+bool LaunchSkipLayerNormKernel(
+    void* output,              // output tensor
+    const void* input,         // input tensor
+    const void* skip,          // skip tensor
+    const void* gamma,         // Layer normalization gamma tensor
+    const void* beta,          // Layer normalization beta tensor
+    const int batch_size,      // batch size (B)
+    const int hidden_size,     // hidden size, it is the leading dimension (ld)
+    const int element_count,   // number of elements in input tensor
+    const size_t element_size  // element size of input tensor
+);
+
+}  // namespace cuda
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc
index 560c7ee4f2..b4af013dab 100644
--- a/onnxruntime/contrib_ops/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda_contrib_kernels.cc
@@ -26,6 +26,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler);
@@ -35,6 +37,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
@@ -61,6 +65,8 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ImageScaler)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ImageScaler)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ImageScaler)>,
@@ -70,6 +76,8 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 9500784652..5ace5c396c 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -209,6 +209,124 @@ void RegisterBertSchemas() {
       .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
       .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
 
+  ONNX_CONTRIB_OPERATOR_SCHEMA(EmbedLayerNormalization)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
+      .SetDoc("Embedding Layer Normalization")
+      .Input(0, "input_ids", "2D words IDs with shape (batch_size, sequence_length)", "T1")
+      .Input(1, "segment_ids", "2D segment IDs with shape (batch_size, sequence_length)", "T1")
+      .Input(2, "mask", "2D attention mask with shape (batch_size, sequence_length)", "T1")
+      .Input(3, "word_embedding", "2D with shape (,hidden_size)", "T")
+      .Input(4, "position_embedding", "2D with shape (, hidden_size)", "T")
+      .Input(5, "segment_embedding", "2D with shape (, hidden_size)", "T")
+      .Input(6, "gamma", "1D gamma tensor for layer normalization with shape (hidden_size)", "T")
+      .Input(7, "beta", "1D beta tensor for layer normalization with shape (hidden_size)", "T")
+      .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+      .Output(1, "mask_index", "1D mask_index tensor with shape (batch_size)", "T1")
+      .TypeConstraint("T1", {"tensor(int32)"}, "Constrain input and output integer tensors types")
+      .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output float tensors types.")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        propagateElemTypeFromInputToOutput(ctx, 3, 0);  // output element type follows word_embedding
+        propagateElemTypeFromInputToOutput(ctx, 0, 1);  // mask_index element type follows input_ids
+        if (!hasNInputShapes(ctx, 8))  // shapes of all eight inputs are read below; without them only types are inferred
+          return;
+
+        auto& input_ids_shape = getInputShape(ctx, 0);
+        auto& input_ids_dims = input_ids_shape.dim();
+
+        auto& segment_ids_shape = getInputShape(ctx, 1);
+        auto& segment_ids_dims = segment_ids_shape.dim();
+
+        auto& mask_shape = getInputShape(ctx, 2);
+        auto& mask_dims = mask_shape.dim();
+
+        if (input_ids_dims.size() != 2 || segment_ids_dims.size() != 2 || mask_dims.size() != 2) {
+          fail_shape_inference("Inputs 0, 1 and 2 shall be 2 dimensions");
+        }
+
+        // Dimension 1 (sequence length) may be symbolic, like dimension 0; compare it only when all three inputs carry a concrete value.
+        const auto& input_ids_seq_len = input_ids_shape.dim(1);
+        const auto& segment_ids_seq_len = segment_ids_shape.dim(1);
+        const auto& mask_seq_len = mask_shape.dim(1);
+        if (input_ids_seq_len.has_dim_value() && segment_ids_seq_len.has_dim_value() && mask_seq_len.has_dim_value() && (input_ids_seq_len.dim_value() != segment_ids_seq_len.dim_value() || input_ids_seq_len.dim_value() != mask_seq_len.dim_value())) {
+          fail_shape_inference("Inputs 0, 1 and 2 shall have same value in dimension 1");
+        }
+
+        // get hidden_size from the last dimension of embedding
+        auto& word_embedding_shape = getInputShape(ctx, 3);
+        auto& word_embedding_dims = word_embedding_shape.dim();
+        if (word_embedding_dims.size() != 2 || !word_embedding_dims[1].has_dim_value()) {
+          fail_shape_inference("word_embedding should have 2 dimensions and dimension size is known.");
+        }
+        int64_t hidden_size = word_embedding_shape.dim(1).dim_value();
+
+        auto& position_embedding_shape = getInputShape(ctx, 4);
+        auto& position_embedding_dims = position_embedding_shape.dim();
+        if (position_embedding_dims.size() != 2) {
+          fail_shape_inference("position_embedding should have 2 dimensions");
+        }
+        if (position_embedding_shape.dim(1).has_dim_value() && position_embedding_shape.dim(1).dim_value() != hidden_size) {  // a symbolic dim reads as 0, so compare only concrete values
+          fail_shape_inference("The last dimension of word_embedding and position_embedding does not match.");
+        }
+
+        auto& segment_embedding_shape = getInputShape(ctx, 5);
+        auto& segment_embedding_dims = segment_embedding_shape.dim();
+        if (segment_embedding_dims.size() != 2) {
+          fail_shape_inference("segment_embedding should have 2 dimensions");
+        }
+        if (segment_embedding_shape.dim(1).has_dim_value() && segment_embedding_shape.dim(1).dim_value() != hidden_size) {
+          fail_shape_inference("The last dimension of word_embedding and segment_embedding does not match.");
+        }
+
+        auto& gamma_shape = getInputShape(ctx, 6);
+        auto& gamma_dims = gamma_shape.dim();
+        if (gamma_dims.size() != 1) {
+          fail_shape_inference("gamma should have 1 dimension");
+        }
+        if (gamma_shape.dim(0).has_dim_value() && gamma_shape.dim(0).dim_value() != hidden_size) {
+          fail_shape_inference("The last dimension of word_embedding and gamma does not match.");
+        }
+
+        auto& beta_shape = getInputShape(ctx, 7);
+        auto& beta_dims = beta_shape.dim();
+        if (beta_dims.size() != 1) {
+          fail_shape_inference("beta should have 1 dimension");
+        }
+        if (beta_shape.dim(0).has_dim_value() && beta_shape.dim(0).dim_value() != hidden_size) {
+          fail_shape_inference("The last dimension of word_embedding and beta does not match.");
+        }
+
+        // mask shape is (batch_size, sequence_length), output shape is (batch_size, sequence_length, hidden_size)
+        ONNX_NAMESPACE::TensorShapeProto output_shape;
+        for (auto& dim : mask_dims) {
+          *output_shape.add_dim() = dim;
+        }
+        if (hidden_size > 0) {
+          output_shape.add_dim();
+          output_shape.mutable_dim(2)->set_dim_value(hidden_size);
+        }
+        updateOutputShape(ctx, 0, output_shape);
+
+        // mask_index shape is (batch_size)
+        ONNX_NAMESPACE::TensorShapeProto mask_index_shape;
+        *mask_index_shape.add_dim() = mask_shape.dim(0);
+        updateOutputShape(ctx, 1, mask_index_shape);
+      });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(SkipLayerNormalization)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
+      .SetDoc("Skip and Layer Normalization Fusion")
+      .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+      .Input(1, "skip", "3D skip tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+      .Input(2, "gamma", "1D input tensor with shape (hidden_size)", "T")
+      .Input(3, "beta", "1D beta tensor with shape (hidden_size)", "T")
+      .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+      .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
+      .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
+
 }
 
 void RegisterContribSchemas() {
@@ -1792,8 +1910,6 @@ Example 4:
         }
       });
 
-  RegisterBertSchemas();
-
   // Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
     RegisterNchwcSchemas();
@@ -1812,6 +1928,8 @@ Example 4:
           "Constrain input and output types to float tensors.")
       .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
 
+  RegisterBertSchemas();
+
 #ifdef MICROSOFT_INTERNAL
   // register internal ops
   RegisterInternalSchemas();
diff --git a/onnxruntime/test/contrib_ops/embedlayernorm_op_test.cc b/onnxruntime/test/contrib_ops/embedlayernorm_op_test.cc
new file mode 100644
index 0000000000..d272d8b8d0
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/embedlayernorm_op_test.cc
@@ -0,0 +1,273 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+static void RunTest(  // Builds one EmbedLayerNormalization case and runs it on the CUDA provider; does nothing when HasCudaEnvironment() is false.
+    const std::vector<int32_t>& input_ids_data,
+    const std::vector<int32_t>& segment_ids_data,
+    const std::vector<int32_t>& mask_data,
+    const std::vector<float>& word_embedding_data,
+    const std::vector<float>& position_embedding_data,
+    const std::vector<float>& segment_embedding_data,
+    const std::vector<float>& gamma_data,
+    const std::vector<float>& beta_data,
+    const std::vector<float>& output_data,
+    const std::vector<int32_t>& mask_index_data,
+    int batch_size,
+    int sequence_length,
+    int hidden_size,
+    bool use_float16 = false) {
+  int min_cuda_architecture = use_float16 ? 530 : 0;  // presumably compute capability 5.3 for half support - confirm against HasCudaEnvironment
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    // Input and output shapes
+    //   Input 0 - input_ids          : (batch_size, sequence_length)
+    //   Input 1 - segment_ids        : (batch_size, sequence_length)
+    //   Input 2 - mask               : (batch_size, sequence_length)
+    //   Input 3 - word_embedding     : (,hidden_size)
+    //   Input 4 - position_embedding : (,hidden_size)
+    //   Input 5 - segment_embedding  : (,hidden_size)
+    //   Input 6 - gamma              : (hidden_size)
+    //   Input 7 - beta               : (hidden_size)
+    //   Output 0 - output            : (batch_size, sequence_length, hidden_size)
+    //   Output 1 - mask_index        : (batch_size)
+
+    std::vector<int64_t> input_ids_dims = {batch_size, sequence_length};
+    std::vector<int64_t> segment_ids_dims = {batch_size, sequence_length};
+    std::vector<int64_t> mask_dims = {batch_size, sequence_length};
+
+    ASSERT_TRUE(word_embedding_data.size() % hidden_size == 0);  // row count of each embedding table is derived from its flat size
+    std::vector<int64_t> word_embedding_dims = {static_cast<int64_t>(word_embedding_data.size() / hidden_size), hidden_size};
+
+    ASSERT_TRUE(position_embedding_data.size() % hidden_size == 0);
+    std::vector<int64_t> position_embedding_dims = {static_cast<int64_t>(position_embedding_data.size() / hidden_size), hidden_size};
+
+    ASSERT_TRUE(segment_embedding_data.size() % hidden_size == 0);
+    std::vector<int64_t> segment_embedding_dims = {static_cast<int64_t>(segment_embedding_data.size() / hidden_size), hidden_size};
+
+    std::vector<int64_t> gamma_dims = {hidden_size};
+    std::vector<int64_t> beta_dims = gamma_dims;
+    std::vector<int64_t> output_dims = {batch_size, sequence_length, hidden_size};
+    std::vector<int64_t> mask_index_dims = {batch_size};
+
+    OpTester tester("EmbedLayerNormalization", 1, onnxruntime::kMSDomain);
+    tester.AddInput<int32_t>("input_ids", input_ids_dims, input_ids_data);
+    tester.AddInput<int32_t>("segment_ids", segment_ids_dims, segment_ids_data);
+    tester.AddInput<int32_t>("mask", mask_dims, mask_data);
+    if (use_float16) {
+      tester.AddInput<MLFloat16>("word_embedding", word_embedding_dims, ToFloat16(word_embedding_data));
+      tester.AddInput<MLFloat16>("position_embedding", position_embedding_dims, ToFloat16(position_embedding_data));
+      tester.AddInput<MLFloat16>("segment_embedding", segment_embedding_dims, ToFloat16(segment_embedding_data));
+      tester.AddInput<MLFloat16>("gamma", gamma_dims, ToFloat16(gamma_data));
+      tester.AddInput<MLFloat16>("beta", beta_dims, ToFloat16(beta_data));
+      tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
+    } else {
+      tester.AddInput<float>("word_embedding", word_embedding_dims, word_embedding_data);
+      tester.AddInput<float>("position_embedding", position_embedding_dims, position_embedding_data);
+      tester.AddInput<float>("segment_embedding", segment_embedding_dims, segment_embedding_data);
+      tester.AddInput<float>("gamma", gamma_dims, gamma_data);
+      tester.AddInput<float>("beta", beta_dims, beta_data);
+      tester.AddOutput<float>("output", output_dims, output_data);
+    }
+    tester.AddOutput<int32_t>("mask_index", mask_index_dims, mask_index_data);
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(EmbedLayerNormTest, EmbedLayerNormBatch1) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<int32_t> input_ids_data = {
+      1, 3};
+
+  std::vector<int32_t> segment_ids_data = {
+      0, 1};
+
+  std::vector<int32_t> mask_data = {
+      1, 1};
+
+  std::vector<float> word_embedding_data = {
+      0.2f, 0.1f, 0.4f, -0.6f,
+      0.3f, 0.2f, 0.5f, 0.6f,
+      0.6f, 0.7f, 0.0f, -0.1f,
+      0.8f, 0.6f, 0.9f, 1.2f,
+      0.1f, 0.3f, 0.5f, 0.9f,
+      1.0f, -2.0f, 1.1f, 0.8f};
+
+  std::vector<float> position_embedding_data = {
+      0.1f, 0.1f, 0.4f, 0.6f,
+      0.6f, 0.0f, 0.8f, 0.6f,
+      0.3f, 0.9f, -2.0f, 0.8f};
+
+  std::vector<float> segment_embedding_data = {
+      0.3f, 0.4f, 0.9f, 0.1f,
+      0.7f, 0.3f, 0.5f, 0.2f};
+
+  std::vector<float> gamma_data = {
+      0.25f, 0.15f, 0.45f, -0.66f};
+
+  std::vector<float> beta_data = {
+      0.6f, 0.2f, 0.5f, -0.6f};
+
+  std::vector<float> output_data = {
+      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
+      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337};
+
+  std::vector<int32_t> mask_index_data = {
+      2};
+
+  RunTest(input_ids_data,
+          segment_ids_data,
+          mask_data,
+          word_embedding_data,
+          position_embedding_data,
+          segment_embedding_data,
+          gamma_data,
+          beta_data,
+          output_data,
+          mask_index_data,
+          batch_size,
+          sequence_length,
+          hidden_size);
+}
+
+TEST(EmbedLayerNormTest, EmbedLayerNormBatch1_Float16) {  // same inputs as EmbedLayerNormBatch1, run through the MLFloat16 path
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<int32_t> input_ids_data = {
+      1, 3};
+
+  std::vector<int32_t> segment_ids_data = {
+      0, 1};
+
+  std::vector<int32_t> mask_data = {
+      1, 1};
+
+  std::vector<float> word_embedding_data = {
+      0.2f, 0.1f, 0.4f, -0.6f,
+      0.3f, 0.2f, 0.5f, 0.6f,
+      0.6f, 0.7f, 0.0f, -0.1f,
+      0.8f, 0.6f, 0.9f, 1.2f,
+      0.1f, 0.3f, 0.5f, 0.9f,
+      1.0f, -2.0f, 1.1f, 0.8f};
+
+  std::vector<float> position_embedding_data = {
+      0.1f, 0.1f, 0.4f, 0.6f,
+      0.6f, 0.0f, 0.8f, 0.6f,
+      0.3f, 0.9f, -2.0f, 0.8f};
+
+  std::vector<float> segment_embedding_data = {
+      0.3f, 0.4f, 0.9f, 0.1f,
+      0.7f, 0.3f, 0.5f, 0.2f};
+
+  std::vector<float> gamma_data = {
+      0.25f, 0.15f, 0.45f, -0.66f};
+
+  std::vector<float> beta_data = {
+      0.6f, 0.2f, 0.5f, -0.6f};
+
+  std::vector<float> output_data = {
+      0.369873046875, 0.061676025390625, 1.1591796875, -0.8515625,
+      0.7431640625, -0.057586669921875, 0.84326171875, -0.8525390625};
+
+  std::vector<int32_t> mask_index_data = {
+      2};
+
+  RunTest(input_ids_data,
+          segment_ids_data,
+          mask_data,
+          word_embedding_data,
+          position_embedding_data,
+          segment_embedding_data,
+          gamma_data,
+          beta_data,
+          output_data,
+          mask_index_data,
+          batch_size,
+          sequence_length,
+          hidden_size,
+          true);
+}
+
+TEST(EmbedLayerNormTest, EmbedLayerNormBatch2) {  // NOTE: despite the name, this case runs with batch_size = 3
+  int batch_size = 3;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<int32_t> input_ids_data = {
+      1, 3,
+      1, 3,
+      2, 0};
+
+  std::vector<int32_t> segment_ids_data = {
+      0, 1,
+      0, 1,
+      0, 0};
+
+  std::vector<int32_t> mask_data = {
+      1, 1,
+      1, 1,
+      1, 0};
+
+  std::vector<float> word_embedding_data = {
+      0.2f, 0.1f, 0.4f, -0.6f,
+      0.3f, 0.2f, 0.5f, 0.6f,
+      0.6f, 0.7f, 0.0f, -0.1f,
+      0.8f, 0.6f, 0.9f, 1.2f,
+      0.1f, 0.3f, 0.5f, 0.9f,
+      1.0f, -2.0f, 1.1f, 0.8f};
+
+  std::vector<float> position_embedding_data = {
+      0.1f, 0.1f, 0.4f, 0.6f,
+      0.6f, 0.0f, 0.8f, 0.6f,
+      0.3f, 0.9f, -2.0f, 0.8f};
+
+  std::vector<float> segment_embedding_data = {
+      0.3f, 0.4f, 0.9f, 0.1f,
+      0.7f, 0.3f, 0.5f, 0.2f};
+
+  std::vector<float> gamma_data = {
+      0.25f, 0.15f, 0.45f, -0.66f};
+
+  std::vector<float> beta_data = {
+      0.6f, 0.2f, 0.5f, -0.6f};
+
+  std::vector<float> output_data = {
+      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
+      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337,
+      0.36917170882225037, 0.061503000557422638, 1.1598974466323853, -0.85092413425445557,
+      0.74301940202713013, -0.057434864342212677, 0.84324657917022705, -0.85171419382095337,
+      0.57668739557266235, 0.2979130744934082, 0.96158987283706665, 0.44627034664154053,
+      0.64977931976318359, 0.11039737612009048, 1.1869535446166992, 0.14469735324382782};
+
+  std::vector<int32_t> mask_index_data = {
+      2, 2, 1};
+
+  RunTest(input_ids_data,
+          segment_ids_data,
+          mask_data,
+          word_embedding_data,
+          position_embedding_data,
+          segment_embedding_data,
+          gamma_data,
+          beta_data,
+          output_data,
+          mask_index_data,
+          batch_size,
+          sequence_length,
+          hidden_size);
+}
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
new file mode 100644
index 0000000000..902497f14c
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc
@@ -0,0 +1,146 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/common/tensor_op_test_utils.h"
+#include "test/common/cuda_op_test_utils.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+static void RunTest(  // Builds one SkipLayerNormalization case and runs it on the CUDA provider; does nothing when HasCudaEnvironment() is false.
+    const std::vector<float>& input_data,
+    const std::vector<float>& skip_data,
+    const std::vector<float>& gamma_data,
+    const std::vector<float>& beta_data,
+    const std::vector<float>& output_data,
+    int batch_size,
+    int sequence_length,
+    int hidden_size,
+    bool use_float16 = false) {
+  int min_cuda_architecture = use_float16 ? 530 : 0;  // presumably compute capability 5.3 for half support - confirm against HasCudaEnvironment
+  if (HasCudaEnvironment(min_cuda_architecture)) {
+    OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
+
+    // Input and output shapes
+    //   Input 0 - input: (batch_size, sequence_length, hidden_size)
+    //   Input 1 - skip : (batch_size, sequence_length, hidden_size)
+    //   Input 2 - gamma: (hidden_size)
+    //   Input 3 - beta : (hidden_size)
+    //   Output         : (batch_size, sequence_length, hidden_size)
+    std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
+    std::vector<int64_t> skip_dims = input_dims;
+    std::vector<int64_t> gamma_dims = {hidden_size};
+    std::vector<int64_t> beta_dims = gamma_dims;
+    std::vector<int64_t> output_dims = input_dims;
+
+    if (use_float16) {  // float test data is converted with ToFloat16 for the half-precision variant
+      test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
+      test.AddInput<MLFloat16>("skip", skip_dims, ToFloat16(skip_data));
+      test.AddInput<MLFloat16>("gamma", gamma_dims, ToFloat16(gamma_data));
+      test.AddInput<MLFloat16>("beta", beta_dims, ToFloat16(beta_data));
+      test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
+    } else {
+      test.AddInput<float>("input", input_dims, input_data);
+      test.AddInput<float>("skip", skip_dims, skip_data);
+      test.AddInput<float>("gamma", gamma_dims, gamma_data);
+      test.AddInput<float>("beta", beta_dims, beta_data);
+      test.AddOutput<float>("output", output_dims, output_data);
+    }
+
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
+TEST(SkipLayerNormTest, SkipLayerNormBatch1) {
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f};
+
+  std::vector<float> skip_data = {
+      0.1f, -0.2f, 0.3f, 1.0f,
+      0.5f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> gamma_data = {
+      0.3f, 0.2f, 4.0f, 2.2f};
+
+  std::vector<float> beta_data = {
+      0.2f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> output_data = {
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
+
+  RunTest(input_data, skip_data, gamma_data, beta_data, output_data,
+          batch_size, sequence_length, hidden_size);
+}
+
+TEST(SkipLayerNormTest, SkipLayerNormBatch1_Float16) {  // same data as SkipLayerNormBatch1, run through the MLFloat16 path
+  int batch_size = 1;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f};
+
+  std::vector<float> skip_data = {
+      0.1f, -0.2f, 0.3f, 1.0f,
+      0.5f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> gamma_data = {
+      0.3f, 0.2f, 4.0f, 2.2f};
+
+  std::vector<float> beta_data = {
+      0.2f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> output_data = {
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438};
+
+  RunTest(input_data, skip_data, gamma_data, beta_data, output_data,
+          batch_size, sequence_length, hidden_size, true);
+}
+
+TEST(SkipLayerNormTest, SkipLayerNormBatch2) {  // two batches share the same input rows but use different skip rows
+  int batch_size = 2;
+  int sequence_length = 2;
+  int hidden_size = 4;
+
+  std::vector<float> input_data = {
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f,
+      0.8f, -0.5f, 0.0f, 1.f,
+      0.5f, 0.2f, 0.3f, -0.6f};
+
+  std::vector<float> skip_data = {
+      0.1f, -0.2f, 0.3f, 1.0f,
+      0.5f, 0.1f, 0.4f, 1.6f,
+      1.8f, -0.3f, 0.0f, 1.f,
+      -0.5f, 0.4f, 0.8f, -0.6f};
+
+  std::vector<float> gamma_data = {
+      0.3f, 0.2f, 4.0f, 2.2f};
+
+  std::vector<float> beta_data = {
+      0.2f, 0.1f, 0.4f, 1.6f};
+
+  std::vector<float> output_data = {
+      0.28433859348297119, -0.17090578377246857, -0.92897164821624756, 4.6924152374267578,
+      0.46111652255058289, -0.21333980560302734, -0.29631003737449646, 3.5148544311523438,
+      0.55470430850982666, -0.15080101788043976, -2.3229825496673584, 3.255286693572998,
+      0.15631480515003204, 0.21066918969154358, 4.9432611465454102, -1.7957965135574341};
+
+  RunTest(input_data, skip_data, gamma_data, beta_data, output_data,
+          batch_size, sequence_length, hidden_size);
+}
+
+}  // namespace test
+}  // namespace onnxruntime