From 473cd5545fde1d9cbcf48b7af898c5837f6e3e80 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 1 Jul 2020 15:18:02 -0700 Subject: [PATCH] Simple support of MatMul U8S8 on ARM to pass tests (#4392) --- onnxruntime/core/util/qmath.cc | 35 ++++++++++--------- .../dynamic_quantize_matmul_test.cc | 2 -- .../contrib_ops/quantize_attention_op_test.cc | 2 -- .../providers/cpu/math/matmul_integer_test.cc | 4 --- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/util/qmath.cc b/onnxruntime/core/util/qmath.cc index 67f84b32b4..d7c7ee4671 100644 --- a/onnxruntime/core/util/qmath.cc +++ b/onnxruntime/core/util/qmath.cc @@ -12,6 +12,19 @@ #endif namespace onnxruntime { + +template +void QGemmWithEigen(const TA* A_data, const TB* B_data, TY* Y_data, int M, int N, int K, TA a_offset, TB b_offset) { + auto A = ConstEigenMatrixMapRowMajor(A_data, M, K); + auto B = ConstEigenMatrixMapRowMajor(B_data, K, N); + + auto A_row_sum = (A.template cast().rowwise().sum()) * static_cast(b_offset); + auto B_col_sum = (B.template cast().colwise().sum()) * static_cast(a_offset); + EigenMatrixMapRowMajor(Y_data, M, N) = A.template cast() * B.template cast() + static_cast(K * a_offset * b_offset) * ConstEigenMatrixMapRowMajor::Ones(M, N); + EigenMatrixMapRowMajor(Y_data, M, N).colwise() -= A_row_sum; + EigenMatrixMapRowMajor(Y_data, M, N).rowwise() -= B_col_sum; +} + template <> void QGemm( int M, @@ -29,20 +42,12 @@ void QGemm( #ifdef MLAS_SUPPORTS_GEMM_U8X8 MlasGemm(M, N, K, lhs_data, lda, lhs_offset, rhs_data, ldb, rhs_offset, result_data, ldc, thread_pool); #else - ORT_UNUSED_PARAMETER(M); - ORT_UNUSED_PARAMETER(N); - ORT_UNUSED_PARAMETER(K); - ORT_UNUSED_PARAMETER(lhs_data); - ORT_UNUSED_PARAMETER(lda); - ORT_UNUSED_PARAMETER(lhs_offset); - ORT_UNUSED_PARAMETER(rhs_data); - ORT_UNUSED_PARAMETER(ldb); - ORT_UNUSED_PARAMETER(rhs_offset); - ORT_UNUSED_PARAMETER(result_data); - ORT_UNUSED_PARAMETER(ldc); ORT_UNUSED_PARAMETER(thread_pool); - ORT_NOT_IMPLEMENTED("MatMulInteger: activation uint8 and weight int8 not supported on ARM"); + ORT_ENFORCE(lda == K && ldb == N && ldc == N, "For Eigen only RowMajor*RowMajor=RowMajor format is supported"); + + QGemmWithEigen(lhs_data, rhs_data, result_data, M, N, K, lhs_offset, rhs_offset); + #endif } @@ -104,8 +109,7 @@ void QGemm( #endif } -template -void QGemm( +template void QGemm( int M, int N, int K, @@ -121,8 +125,7 @@ void QGemm( const float* bias, concurrency::ThreadPool* thread_pool); -template -void QGemm( +template void QGemm( int M, int N, int K, diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index 991cbf87bd..b66fd66ae2 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -49,13 +49,11 @@ void TestDynamicQuantizeMatMul(const std::vector& A_dims, } TEST(DynamicQuantizeMatMul, Int8_test) { -#ifdef MLAS_SUPPORTS_GEMM_U8X8 std::vector A_dims{4, 128}; std::vector B_dims{128, 128}; std::vector Y_dims{4, 128}; TestDynamicQuantizeMatMul(A_dims, B_dims, "testdata/dynamic_quantize_matmul_int8.onnx"); -#endif } TEST(DynamicQuantizeMatMul, UInt8_test) { diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 4b4dfb2b79..1dd6d4f2c9 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -135,11 +135,9 @@ void RunQAttention(const std::vector& input_data, // input: execution_providers.push_back(DefaultCudaExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else { -#ifdef MLAS_SUPPORTS_GEMM_U8X8 std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); -#endif } } diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index 09d7b7a2f7..d5abeb5406 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -302,13 +302,9 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze } } -#ifdef MLAS_SUPPORTS_GEMM_U8X8 #define RUN_MATMUL_INTEGER_U8S8(M, N, K) \ RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \ RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/); -#else -#define RUN_MATMUL_INTEGER_U8S8(M, N, K) -#endif // MLAS_SUPPORTS_GEMM_U8X8 TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) { RUN_MATMUL_INTEGER_U8S8(1, 1, 32);