From a5e134405d9f8cfa3dc9ec012f50bafb9b95c806 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Tue, 1 Oct 2019 20:32:28 -0700 Subject: [PATCH] Support opset-11 Gemm kernels (#1923) * Support optional bias in Gemm * Fix test * Update * More updates * Update * Update * Update gemm.cc * Update * Update * Fix build break * Update * PR comments * Update --- .../providers/cpu/cpu_execution_provider.cc | 6 ++-- onnxruntime/core/providers/cpu/math/gemm.cc | 10 +++++- onnxruntime/core/providers/cpu/math/gemm.h | 20 +++++++----- .../core/providers/cpu/math/gemm_helper.h | 17 +++++----- .../providers/cuda/cuda_execution_provider.cc | 23 ++++++++++---- onnxruntime/core/providers/cuda/math/gemm.cc | 31 +++++++++++++------ onnxruntime/test/onnx/main.cc | 1 - .../test/providers/cpu/math/gemm_test.cc | 19 ++++++++++++ .../test/python/onnx_backend_test_series.py | 2 -- 9 files changed, 92 insertions(+), 37 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f1a068c1f7..b1b16d4b8e 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -261,7 +261,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, BatchNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t, MatMul); @@ -358,6 +358,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Lp class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Gemm); void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -604,7 +605,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -701,6 +702,7 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/math/gemm.cc b/onnxruntime/core/providers/cpu/math/gemm.cc index f2f8e3f920..49d625c1ee 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.cc +++ b/onnxruntime/core/providers/cpu/math/gemm.cc @@ -13,9 +13,17 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( Gemm); // opset 9 added support for additional types (int32, uint32, int64, uint64), however we haven't enabled those yet. +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Gemm, + 9, + 10, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gemm); + +// opset 11 made bias input 'C' optional ONNX_CPU_OPERATOR_KERNEL( Gemm, - 9, + 11, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), Gemm); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/gemm.h b/onnxruntime/core/providers/cpu/math/gemm.h index 225754141a..fed4c987e1 100644 --- a/onnxruntime/core/providers/cpu/math/gemm.h +++ b/onnxruntime/core/providers/cpu/math/gemm.h @@ -31,10 +31,12 @@ class Gemm : public OpKernel { auto ctx_internal = static_cast(context); concurrency::ThreadPool* tp = ctx_internal->GetOperatorThreadPool(); - const auto X = context->Input(0); - const auto W = context->Input(1); - const auto B = context->Input(2); - GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, B->Shape()); + const auto* X = context->Input(0); + const auto* W = context->Input(1); + const auto* B = context->Input(2); + // Bias could be missing. Treat as scalar 0 if that is the case. + GemmHelper helper(X->Shape(), trans_A_ != CblasNoTrans, W->Shape(), trans_B_ != CblasNoTrans, + B != nullptr ? B->Shape() : TensorShape({})); if (!helper.State().IsOK()) return helper.State(); @@ -42,13 +44,13 @@ class Gemm : public OpKernel { int64_t M = helper.M(); int64_t N = helper.N(); auto Y = context->Output(0, {M, N}); - // if input is emtpy tensor, return directly as nothing need to be calculated. + // if input is empty tensor, return directly as nothing need to be calculated. if (M == 0 || N == 0) return Status::OK(); T* y_data = Y->template MutableData(); - // Broadcast the bias as needed. - if (beta_ != 0) { + // Broadcast the bias as needed if bias is given + if (beta_ != 0 && B != nullptr) { auto output_mat = EigenMatrixMapRowMajor(y_data, M, N); const auto& b_shape = B->Shape(); const T* b_data = B->template Data(); @@ -77,7 +79,9 @@ class Gemm : public OpKernel { alpha_, X->template Data(), W->template Data(), - beta_, + // ideally we need to set the output buffer contents to 0 if bias is missing, + // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer + B != nullptr ? beta_ : 0, y_data, tp); diff --git a/onnxruntime/core/providers/cpu/math/gemm_helper.h b/onnxruntime/core/providers/cpu/math/gemm_helper.h index 46bd5e734f..7b7240217d 100644 --- a/onnxruntime/core/providers/cpu/math/gemm_helper.h +++ b/onnxruntime/core/providers/cpu/math/gemm_helper.h @@ -50,17 +50,18 @@ class GemmHelper { Status State() const { return status_; } private: - bool IsValidBroadcast(const TensorShape& shape, int64_t M, int64_t N) { - if (shape.NumDimensions() != 1 && shape.NumDimensions() != 2) + bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) { + // valid shapes are (,) , (1, N) , (M, 1) , (M, N) + if (bias_shape.NumDimensions() > 2) return false; // shape is (1,) or (1, 1), or (,) - if (shape.Size() == 1) + if (bias_shape.Size() == 1) return true; - // shape is (N,) or (1, N) or (M, 1) - // or (M, N), in last case no broadcast needed, but don't fail it - return ((shape.NumDimensions() == 1 && shape[0] == N) || - (shape.NumDimensions() == 2 && shape[0] == M && (shape[1] == 1 || shape[1] == N)) || - (shape.NumDimensions() == 2 && shape[0] == 1 && shape[1] == N)); + // valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), + // In last case no broadcasting needed, so don't fail it + return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) || + (bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) || + (bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N)); } private: diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3d0a8f8702..a8b9b240e3 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -214,9 +214,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, Ga class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, MLFloat16, MatMul); @@ -537,10 +537,16 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16, Less); +// opset 10 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, Dropout); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, float, RoiAlign); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, double, RoiAlign); +// opset 11 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, double, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, MLFloat16, Gemm); + static void RegisterCudaKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, @@ -555,9 +561,9 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -880,6 +886,11 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 11 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 9fcf089f31..87e500cc19 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -20,10 +20,20 @@ namespace cuda { KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ Gemm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ Gemm, \ kOnnxDomain, \ 9, \ + 10, \ + T, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 11, \ T, \ kCudaExecutionProvider, \ KernelDefBuilder() \ @@ -38,10 +48,11 @@ template Status Gemm::ComputeInternal(OpKernelContext* ctx) const { typedef typename ToCudaType::MappedType CudaT; - const auto X = ctx->Input(0); - const auto W = ctx->Input(1); - const auto B = ctx->Input(2); - GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B->Shape()); + const auto* X = ctx->Input(0); + const auto* W = ctx->Input(1); + const auto* B = ctx->Input(2); + // Bias could be missing. Treat as scalar 0 if that is the case. + GemmHelper helper(X->Shape(), trans_A_, W->Shape(), trans_B_, B != nullptr ? B->Shape() : TensorShape({})); if (!helper.State().IsOK()) return helper.State(); @@ -49,14 +60,14 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { int M = gsl::narrow_cast(helper.M()); int N = gsl::narrow_cast(helper.N()); int K = gsl::narrow_cast(helper.K()); - auto Y = ctx->Output(0, TensorShape(std::vector{M, N})); + auto* Y = ctx->Output(0, TensorShape(std::vector{M, N})); CudaT* out_data = reinterpret_cast(Y->template MutableData()); CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = ToCudaType::FromFloat(0.0f); - // broadcast bias if needed - if (beta_ != 0) { + // broadcast bias if needed and is present + if (beta_ != 0 && B != nullptr) { auto& b_shape = B->Shape(); const CudaT* b_data = reinterpret_cast(B->template Data()); @@ -112,7 +123,9 @@ Status Gemm::ComputeInternal(OpKernelContext* ctx) const { (trans_B_ ? K : N), reinterpret_cast(X->template Data()), (trans_A_ ? M : K), - &beta, + // ideally we need to set the output buffer contents to 0 if bias is missing, + // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer + B != nullptr ? &beta : &zero, out_data, N)); return Status::OK(); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 031913b7ac..a22290d31f 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -471,7 +471,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"bitshift_left_uint64", "BitShift(11) not implemented yet"}, {"bitshift_left_uint32", "BitShift(11) not implemented yet"}, {"bitshift_left_uint16", "BitShift(11) not implemented yet"}, - {"gemm_default_scalar_bias", "Gemm ValidBroadcast() has bug to be fixed."}, }; #ifdef USE_NGRAPH diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 8673f9cc6b..6c22f13038 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -232,5 +232,24 @@ TEST(GemmOpTest, GemmEmptyTensor) { test.Run(); } +TEST(GemmOpTest, GemmNoBiasOpset11) { + OpTester test("Gemm", 11); + + test.AddAttribute("transA", static_cast(0)); + test.AddAttribute("transB", static_cast(0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + test.AddInput("A", {2, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + test.AddInput("B", {4, 3}, std::vector(12, 1.0f)); + test.AddOutput("Y", {2, 3}, + {10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}); + // NGraph and tensorRT don't seem to support missing bias + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kNGraphExecutionProvider, kTensorrtExecutionProvider}); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index ae24fd917a..02df7e0664 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -148,8 +148,6 @@ def create_backend_test(testname=None): '^test_reduce_*', '^test_onehot_*', '^test_constant_pad_cpu.*', - '^test_gemm_default_scalar_bias_cpu.*', - '^test_gemm_*', '^test_edge_pad_cpu.*', '^test_reflect_pad_cpu.*' )