From bb09acffed8323ea13664bbff96c4914b875b477 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Thu, 3 Feb 2022 12:58:49 -0800 Subject: [PATCH] Transformer model CUDA EP align with CPU on corner case (#9889) * align with cpu on no input data * review comments and add tests Co-authored-by: Ubuntu --- .../contrib_ops/cuda/bert/fast_gelu.cc | 3 ++ .../contrib_ops/cuda/bert/skip_layer_norm.cc | 14 ++++----- onnxruntime/contrib_ops/cuda/layer_norm.cc | 10 +++---- .../test/contrib_ops/fastgelu_op_test.cc | 13 +++++++++ .../test/contrib_ops/layer_norm_test.cc | 5 ++++ .../test/contrib_ops/skiplayernorm_op_test.cc | 29 +++++++++++++++++++ 6 files changed, 59 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 29f823d994..fe74f7efd3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -44,6 +44,9 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); int64_t input_length = input->Shape().Size(); + if (input_length == 0) { + return Status::OK(); + } int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index d6cac35261..1baaaa69e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -41,12 +41,13 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { Tensor* output = ctx->Output(0, input->Shape()); - if (input->SizeInBytes() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'input' has no data from upstream nodes"); + if (input->Shape() != skip->Shape()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip is expected to have same shape as input"); } - if (skip->SizeInBytes() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'skip' has no data from upstream nodes"); + if (input->Shape().Size() == 0) { + return Status::OK(); } const auto& input_dims = input->Shape().GetDims(); @@ -55,11 +56,6 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input->Shape() != skip->Shape()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "skip is expected to have same shape as input"); - } - const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc index 3f7360bd77..7cb3f07ac6 100644 --- a/onnxruntime/contrib_ops/cuda/layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc @@ -59,12 +59,6 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) const auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast(bias->template Data()); const TensorShape& x_shape = X->Shape(); - // Sometimes due to conversion issue, the input 'X' has no data which is a case that cuda kernel cannot handle. - // Provide more error infomation here instead of CUDA errors. - if (X->SizeInBytes() == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'X' has no data from upstream nodes"); - } - const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions()); int n1 = gsl::narrow(x_shape.SizeToDimension(axis)); @@ -101,6 +95,10 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) const inv_var_data = reinterpret_cast(var->template MutableData()); } + if (x_shape.Size() == 0) { + return Status::OK(); + } + HostApplyLayerNorm(GetDeviceProp(), Stream(), Y_data, mean_data, inv_var_data, X_data, n1, n2, epsilon_, scale_data, bias_data); return Status::OK(); } diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5e0c513397..5614e065ac 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -110,6 +110,19 @@ static void RunFastGeluTest( RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); } +TEST(FastGeluTest, FastGeluWithNullInput) { + int batch_size = 1; + int sequence_length = 0; + int hidden_size = 4; + + std::vector input_data = {}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f}; + + RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); +} + TEST(FastGeluTest, FastGeluWithBiasFloat32) { int batch_size = 1; int sequence_length = 2; diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 1e2203d520..9c41f218a2 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -80,6 +80,11 @@ static void TestLayerNorm(const std::vector& x_dims, #endif } +TEST(CudaKernelTest, LayerNorm_NullInput) { + const std::vector X_dims{0, 20, 128}; + TestLayerNorm(X_dims, LAYER_NORM_OP, k_epsilon_default); +} + TEST(CudaKernelTest, LayerNorm_SmallSizeTensor) { const std::vector X_dims{4, 20, 128}; TestLayerNorm(X_dims, LAYER_NORM_OP, k_epsilon_default); diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index 5a9a4dce6e..009201f130 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -83,6 +83,35 @@ static void RunTest( } } +TEST(SkipLayerNormTest, SkipLayerNormNullInput) { + int batch_size = 1; + int sequence_length = 0; + int hidden_size = 4; + + std::vector input_data = {}; + + std::vector skip_data = {}; + + std::vector gamma_data = { + 0.3f, 0.2f, 4.0f, 2.2f}; + + std::vector beta_data = { + 0.2f, 0.1f, 0.4f, 1.6f}; + + std::vector output_data = {}; + + RunTest(input_data, + skip_data, + gamma_data, + beta_data, + std::vector(), + output_data, + epsilon_, + batch_size, + sequence_length, + hidden_size); +} + TEST(SkipLayerNormTest, SkipLayerNormBatch1) { int batch_size = 1; int sequence_length = 2;