From c2b8ac01544b5b2012ff60009aefb59e53606070 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 20 Feb 2019 17:03:37 -0800 Subject: [PATCH] MatMul op: Support new integer types and double type as part of opset V9 compliance (#482) * Support new integer types and double type as part of opset V9 compliance --- .../providers/cpu/cpu_execution_provider.cc | 14 +- onnxruntime/core/providers/cpu/math/matmul.cc | 59 +++- onnxruntime/core/util/math_cpu.cc | 278 +++++++++++++++--- .../test/providers/cpu/math/matmul_test.cc | 163 ++++++---- 4 files changed, 404 insertions(+), 110 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d6fd83eb64..83a5a727c7 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -89,7 +89,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Ata class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Hardmax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint32_t, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int64_t, MatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, uint64_t, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Softmax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, BatchNormalization); @@ -342,7 +347,12 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); - kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); + kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); kernel_registry.Register(BuildKernelCreateInfo()); diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index b66f302656..845fafaabb 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -9,15 +9,50 @@ namespace onnxruntime { -ONNX_CPU_OPERATOR_VERSIONED_KERNEL( - MatMul, - 1, - 9, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - MatMul); +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 1, 9, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); -template <> -Status MatMul::Compute(OpKernelContext* ctx) const { +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 1, 9, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 9, + int32_t, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 9, + uint32_t, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 9, + int64_t, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( + MatMul, + 9, 9, + uint64_t, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +template +Status MatMul::Compute(OpKernelContext* ctx) const { const Tensor* left_X = ctx->Input(0); const Tensor* right_X = ctx->Input(1); @@ -28,17 +63,17 @@ Status MatMul::Compute(OpKernelContext* ctx) const { // TODO: replace it with GemmBatch for performance, it's OK for now as GemmBatch unrolls as well for (int i = 0; i < helper.OutputOffsets().size(); i++) { - math::Gemm( + math::Gemm( CblasNoTrans, CblasNoTrans, static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), /* alpha */ 1.0f, - left_X->template Data() + helper.LeftOffsets()[i], - right_X->template Data() + helper.RightOffsets()[i], + left_X->template Data() + helper.LeftOffsets()[i], + right_X->template Data() + helper.RightOffsets()[i], /* beta */ 0.0f, - Y->template MutableData() + helper.OutputOffsets()[i], + Y->template MutableData() + helper.OutputOffsets()[i], &CPUMathUtil::Instance()); } diff --git a/onnxruntime/core/util/math_cpu.cc b/onnxruntime/core/util/math_cpu.cc index 432a9e03c8..8459eb019a 100644 --- a/onnxruntime/core/util/math_cpu.cc +++ b/onnxruntime/core/util/math_cpu.cc @@ -51,6 +51,59 @@ namespace onnxruntime { namespace math { +// Gemm implementation purely based on Eigen. +template +void GemmEigen( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + int64_t M, + int64_t N, + int64_t K, + float alpha, + const T* A, + const T* B, + float beta, + T* C) { + auto C_mat = EigenMatrixMap(C, N, M); + if (beta == 0) { + C_mat.setZero(); + } else { + C_mat *= static_cast(beta); + } + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += static_cast(alpha) * (ConstEigenMatrixMap(B, N, K) * + ConstEigenMatrixMap(A, K, M)); + return; + case CblasTrans: + C_mat.noalias() += static_cast(alpha) * (ConstEigenMatrixMap(B, K, N).transpose() * + ConstEigenMatrixMap(A, K, M)); + return; + default: + ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += static_cast(alpha) * (ConstEigenMatrixMap(B, N, K) * + ConstEigenMatrixMap(A, M, K).transpose()); + return; + case CblasTrans: + C_mat.noalias() += static_cast(alpha) * (ConstEigenMatrixMap(B, K, N).transpose() * + ConstEigenMatrixMap(A, M, K).transpose()); + return; + default: + ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); + } + } + default: + ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); + } +} + //////////////////////////////////////////////////////////////////////////////// // BLAS alternatives. // Depending on whether we have specified an external BLAS library or not, we @@ -110,47 +163,100 @@ void Gemm( int ldb = (int)((TransB == CblasNoTrans) ? N : K); MlasSgemm(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N); #else - auto C_mat = EigenMatrixMap(C, N, M); - if (beta == 0) { - C_mat.setZero(); - } else { - C_mat *= beta; - } - switch (TransA) { - case CblasNoTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, N, K) * - ConstEigenMatrixMap(A, K, M)); - return; - case CblasTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, K, N).transpose() * - ConstEigenMatrixMap(A, K, M)); - return; - default: - ORT_THROW("CblasNoTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - case CblasTrans: { - switch (TransB) { - case CblasNoTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, N, K) * - ConstEigenMatrixMap(A, M, K).transpose()); - return; - case CblasTrans: - C_mat.noalias() += alpha * (ConstEigenMatrixMap(B, K, N).transpose() * - ConstEigenMatrixMap(A, M, K).transpose()); - return; - default: - ORT_THROW("CblasTrans Unexpected CBLAS_TRANSPOSE for TransB of ", TransB); - } - } - default: - ORT_THROW("Unexpected CBLAS_TRANSPOSE for TransA of ", TransA); - } + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); #endif } +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const double* A, + const double* B, + const float beta, + double* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No double precision Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const int32_t* A, + const int32_t* B, + const float beta, + int32_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No int32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const uint32_t* A, + const uint32_t* B, + const float beta, + uint32_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No uint32_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const int64_t* A, + const int64_t* B, + const float beta, + int64_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No int64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const uint64_t* A, + const uint64_t* B, + const float beta, + uint64_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No uint64_t Gemm offering from MLAS or MKLDNN. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + template <> void GemmEx( const CBLAS_TRANSPOSE TransA, @@ -343,6 +449,102 @@ void Gemm( beta, C, gsl::narrow_cast(N)); } +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const double* A, + const double* B, + const float beta, + double* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + int lda = gsl::narrow_cast((TransA == CblasNoTrans) ? K : M); + int ldb = gsl::narrow_cast((TransB == CblasNoTrans) ? N : K); + cblas_dgemm(CblasRowMajor, TransA, TransB, + gsl::narrow_cast(M), + gsl::narrow_cast(N), + gsl::narrow_cast(K), + gsl::narrow_cast(alpha), A, lda, B, ldb, + gsl::narrow_cast(beta), C, gsl::narrow_cast(N)); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const int32_t* A, + const int32_t* B, + const float beta, + int32_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No int32_t Gemm offering from MKLML. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const uint32_t* A, + const uint32_t* B, + const float beta, + uint32_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No uint32_t Gemm offering from MKLML. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const int64_t* A, + const int64_t* B, + const float beta, + int64_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No int64_t Gemm offering from MKLML. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, + const CBLAS_TRANSPOSE TransB, + const int64_t M, + const int64_t N, + const int64_t K, + const float alpha, + const uint64_t* A, + const uint64_t* B, + const float beta, + uint64_t* C, + CPUMathUtil* /*provider*/, + MLDataType /*math_type*/) { + // No uint64_t Gemm offering from MKLML. Directly fallback to Eigen. + GemmEigen(TransA, TransB, M, N, K, alpha, A, B, beta, C); +} + template <> void GemmEx( const CBLAS_TRANSPOSE TransA, diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 19e9c0bb92..70721fe700 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -7,75 +7,122 @@ namespace onnxruntime { namespace test { -TEST(MathOpTest, MatMul) { - std::vector vals{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; +template +struct MatMulTestData { + std::string name; + std::vector input0_dims; + std::vector input1_dims; + std::vector expected_dims; + std::vector expected_vals; +}; - struct MatMulTest { - std::string name; - std::vector input0_dims; - std::vector input1_dims; - std::vector expected_dims; - std::vector expected_vals; - }; +template +std::vector> GenerateTestCases() +{ + std::vector> test_cases; - MatMulTest testcases[] = { - {"test padding and broadcast", - {3, 1, 1, 2}, - {2, 2, 2}, - {3, 2, 1, 2}, - {2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}}, - {"test padding and broadcast", - {2, 3, 2}, - {3, 2, 2, 1}, - {3, 2, 3, 1}, - {1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}}, - {"test left 1D", - {2}, - {3, 2, 1}, - {3, 1}, - {1, 3, 5}}, - {"test right 1D", - {3, 1, 2}, - {2}, - {3, 1}, - {1, 3, 5}}, - {"test scalar output", - {3}, - {3}, - {}, - {5}}, - {"test 2D", - {3, 4}, - {4, 3}, - {3, 3}, - {42, 48, 54, 114, 136, 158, 186, 224, 262}}, - {"test 2D special", - {2, 2, 3}, - {3, 4}, - {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}, - {"test 2D special 2", - {2, 2, 3}, - {1, 3, 4}, - {2, 2, 4}, - {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}, - }; + test_cases.push_back( + {"test padding and broadcast", + {3, 1, 1, 2}, + {2, 2, 2}, + {3, 2, 1, 2}, + {2, 3, 6, 7, 6, 11, 26, 31, 10, 19, 46, 55}}); - for (auto t : testcases) { - OpTester test("MatMul"); + test_cases.push_back( + {"test padding and broadcast", + {2, 3, 2}, + {3, 2, 2, 1}, + {3, 2, 3, 1}, + {1, 3, 5, 33, 43, 53, 5, 23, 41, 85, 111, 137, 9, 43, 77, 137, 179, 221}}); + + test_cases.push_back( + {"test left 1D", + {2}, + {3, 2, 1}, + {3, 1}, + {1, 3, 5}}); + + test_cases.push_back( + {"test right 1D", + {3, 1, 2}, + {2}, + {3, 1}, + {1, 3, 5}}); + + test_cases.push_back( + {"test scalar output", + {3}, + {3}, + {}, + {5}}); + + test_cases.push_back( + {"test 2D", + {3, 4}, + {4, 3}, + {3, 3}, + {42, 48, 54, 114, 136, 158, 186, 224, 262}}); + + test_cases.push_back( + {"test 2D special", + {2, 2, 3}, + {3, 4}, + {2, 2, 4}, + {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + + test_cases.push_back( + {"test 2D special 2", + {2, 2, 3}, + {1, 3, 4}, + {2, 2, 4}, + {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + + return test_cases; +} + +template +void RunMatMulTest(int32_t opset_version = 7) +{ + std::vector common_input_vals{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + for (auto t : GenerateTestCases()) { + OpTester test("MatMul", opset_version); int64_t size0 = TensorShape::ReinterpretBaseType(t.input0_dims).SizeHelper(0, t.input0_dims.size()); - std::vector input0_vals(vals.cbegin(), vals.cbegin() + size0); - test.AddInput("A", t.input0_dims, input0_vals); + std::vector input0_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size0); + test.AddInput("A", t.input0_dims, input0_vals); int64_t size1 = TensorShape::ReinterpretBaseType(t.input1_dims).SizeHelper(0, t.input1_dims.size()); - std::vector input1_vals(vals.cbegin(), vals.cbegin() + size1); - test.AddInput("B", t.input1_dims, input1_vals); + std::vector input1_vals(common_input_vals.cbegin(), common_input_vals.cbegin() + size1); + test.AddInput("B", t.input1_dims, input1_vals); - test.AddOutput("Y", t.expected_dims, t.expected_vals); + test.AddOutput("Y", t.expected_dims, t.expected_vals); test.Run(); } } +TEST(MathOpTest, MatMulFloatType) { + RunMatMulTest(); +} + +TEST(MathOpTest, MatMulDoubleType) { + RunMatMulTest(); +} + +TEST(MathOpTest, MatMulInt32Type) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint32Type) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulInt64Type) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint64Type) { + RunMatMulTest(9); +} + } // namespace test } // namespace onnxruntime