From b99eaa99cdf3cc4403564ebce2a210eb2db5dd90 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 9 Oct 2020 02:37:19 -0700 Subject: [PATCH] Prepacking MatMulInteger (#5403) * prepack matmulinteger Prepacking constant matrix B for MatMulInteger to get better performance. --- .../quantization/dynamic_quantize_matmul.cc | 50 ++--------------- .../core/providers/cpu/math/matmul_integer.cc | 33 +++++++++--- .../providers/cpu/math/matmul_integer_base.h | 54 +++++++++++++++++++ .../providers/cpu/math/matmul_integer_test.cc | 53 ++++++++++-------- 4 files changed, 112 insertions(+), 78 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/math/matmul_integer_base.h diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 7092fec005..fabddfb17b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,28 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/framework/op_kernel.h" #include "core/common/safeint.h" -#include "core/providers/common.h" #include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/cpu/math/matmul_integer_base.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" -#include "core/mlas/inc/mlas.h" #include namespace onnxruntime { namespace contrib { -class MatMulIntegerToFloatBase : public OpKernel { +class MatMulIntegerToFloatBase : public MatMulIntegerBase { public: - MatMulIntegerToFloatBase(const OpKernelInfo& info) : OpKernel(info) { + MatMulIntegerToFloatBase(const OpKernelInfo& info) : MatMulIntegerBase(info) { } -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 - Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override; -#endif - protected: Status ComputeCommon(OpKernelContext* ctx, const uint8_t* a_data, @@ -32,46 +26,8 @@ class MatMulIntegerToFloatBase : public OpKernel { uint8_t b_zero_point, float multiplier, const Tensor* bias_tensor) const; - - bool b_is_signed_; - TensorShape b_shape_; - BufferUniquePtr packed_b_; }; -#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 -Status MatMulIntegerToFloatBase::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) { - is_packed = false; - - // only pack Matrix B - if (input_idx == 1) { - // Only handle the common case of a 2D weight matrix. Additional matrices - // could be handled by stacking the packed buffers. - b_shape_ = tensor.Shape(); - if (b_shape_.NumDimensions() != 2) { - return Status::OK(); - } - - const size_t K = static_cast(b_shape_[0]); - const size_t N = static_cast(b_shape_[1]); - - const auto* b_data = static_cast(tensor.DataRaw()); - b_is_signed_ = tensor.IsDataType(); - - const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_); - if (packed_b_size == 0) { - return Status::OK(); - } - - auto alloc = Info().GetAllocator(0, OrtMemTypeDefault); - auto* packed_b_data = alloc->Alloc(packed_b_size); - packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc)); - MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data); - is_packed = true; - } - return Status::OK(); -} -#endif - Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, const uint8_t* a_data, const TensorShape& a_shape, diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer.cc b/onnxruntime/core/providers/cpu/math/matmul_integer.cc index a36d3dd254..4bf3912d4b 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cpu/math/matmul_integer.cc @@ -1,17 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/framework/op_kernel.h" +#include "matmul_integer_base.h" + #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" -#include "core/providers/common.h" namespace onnxruntime { -class MatMulInteger final : public OpKernel { +class MatMulInteger final : public MatMulIntegerBase { public: - MatMulInteger(const OpKernelInfo& info) : OpKernel(info) {} + MatMulInteger(const OpKernelInfo& info) : MatMulIntegerBase(info) {} Status Compute(OpKernelContext* context) const override; }; @@ -32,10 +32,10 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const auto* a = ctx->Input(0); - const auto* b = ctx->Input(1); + const Tensor* b = packed_b_ ? nullptr : ctx->Input(1); MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape())); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), packed_b_ ? b_shape_ : b->Shape())); Tensor* y = ctx->Output(0, helper.OutputShape()); // Bail out early if the output is going to be empty @@ -59,11 +59,28 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const { } const auto* a_data = a->template Data(); - const auto* b_data = static_cast(b->DataRaw()); - const bool b_is_signed = b->IsDataType(); auto* y_data = y->template MutableData(); for (size_t i = 0; i < helper.OutputOffsets().size(); i++) { +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + if (packed_b_) { + MlasGemm(static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + a_data + helper.LeftOffsets()[i], + static_cast(helper.K()), + a_offset, + packed_b_.get(), + b_offset, + b_is_signed_, + y_data + helper.OutputOffsets()[i], + static_cast(helper.N()), + thread_pool); + continue; + } +#endif + const auto* b_data = static_cast(b->DataRaw()); + const bool b_is_signed = b->IsDataType(); QGemm(static_cast(helper.M()), static_cast(helper.N()), static_cast(helper.K()), diff --git a/onnxruntime/core/providers/cpu/math/matmul_integer_base.h b/onnxruntime/core/providers/cpu/math/matmul_integer_base.h new file mode 100644 index 0000000000..ad5bd7d2bd --- /dev/null +++ b/onnxruntime/core/providers/cpu/math/matmul_integer_base.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/providers/common.h" + +namespace onnxruntime { + +class MatMulIntegerBase : public OpKernel { + public: + MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {} + +#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8 + Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override { + is_packed = false; + + // only pack Matrix B + if (input_idx == 1) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + b_shape_ = tensor.Shape(); + if (b_shape_.NumDimensions() != 2) { + return Status::OK(); + } + + const size_t K = static_cast(b_shape_[0]); + const size_t N = static_cast(b_shape_[1]); + + const auto* b_data = static_cast(tensor.DataRaw()); + b_is_signed_ = tensor.IsDataType(); + + const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_); + if (packed_b_size == 0) { + return Status::OK(); + } + + auto alloc = Info().GetAllocator(0, OrtMemTypeDefault); + auto* packed_b_data = alloc->Alloc(packed_b_size); + packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc)); + MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data); + is_packed = true; + } + return Status::OK(); + } +#endif + + protected: + bool b_is_signed_; + TensorShape b_shape_; + BufferUniquePtr packed_b_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc index b54c04fe82..c844890894 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_integer_test.cc @@ -273,11 +273,12 @@ T GetMiddle(const std::vector& v) { } // [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim] -void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_zero_zp) { +template +void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_zero_zp, bool B_is_initializer) { OpTester test("MatMulInteger", 10); static std::default_random_engine e(123); static std::uniform_int_distribution n_unsigned(0, 127); - static std::uniform_int_distribution n_signed(-128, 127); + static std::uniform_int_distribution n_xint8(std::numeric_limits::min(), std::numeric_limits::max()); Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M) .unaryExpr([](int) { return n_unsigned(e); }); @@ -286,18 +287,18 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M); Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K) - .unaryExpr([](int) { return n_signed(e); }); - std::vector matrix_b_data = ToVector(matrix_b.data(), N * K); - int8_t b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0; + .unaryExpr([](int) { return n_xint8(e); }); + std::vector matrix_b_data = ToVector(matrix_b.data(), N * K); + ScalarB b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0; Eigen::MatrixXi matrix_b_offset = matrix_b - b_zero_point * Eigen::MatrixXi::Ones(N, K); Eigen::MatrixXi matrix_c = (matrix_b_offset * matrix_a_offset).eval(); test.AddInput("T1", {M, K}, std::move(matrix_a_data)); - test.AddInput("T2", {K, N}, std::move(matrix_b_data), true /*is_initializer*/); + test.AddInput("T2", {K, N}, std::move(matrix_b_data), B_is_initializer); if (non_zero_zp) { test.AddInput("a_zero_point", {}, {a_zero_point}); - test.AddInput("b_zero_point", {}, {b_zero_point}); + test.AddInput("b_zero_point", {}, {b_zero_point}); } test.AddOutput("T3", {M, N}, ToVector(matrix_c.data(), M * N)); @@ -311,30 +312,36 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze } } -#define RUN_MATMUL_INTEGER_U8S8(M, N, K) \ - RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \ - RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/); +#define RUN_MATMUL_INTEGER_U8X8(M, N, K) \ + RunMatMulIntegerU8X8Test(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \ + RunMatMulIntegerU8X8Test(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/); TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) { - RUN_MATMUL_INTEGER_U8S8(1, 1, 32); - RUN_MATMUL_INTEGER_U8S8(1, 1, 260); - RUN_MATMUL_INTEGER_U8S8(1, 1, 288); + RUN_MATMUL_INTEGER_U8X8(1, 1, 32); + RUN_MATMUL_INTEGER_U8X8(1, 1, 260); + RUN_MATMUL_INTEGER_U8X8(1, 1, 288); } TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMV) { - RUN_MATMUL_INTEGER_U8S8(1, 2, 16); - RUN_MATMUL_INTEGER_U8S8(1, 2, 64); - RUN_MATMUL_INTEGER_U8S8(1, 8, 36); - RUN_MATMUL_INTEGER_U8S8(1, 8, 68); - RUN_MATMUL_INTEGER_U8S8(1, 8, 400); - RUN_MATMUL_INTEGER_U8S8(1, 512, 1024); + RUN_MATMUL_INTEGER_U8X8(1, 2, 16); + RUN_MATMUL_INTEGER_U8X8(1, 2, 64); + RUN_MATMUL_INTEGER_U8X8(1, 8, 36); + RUN_MATMUL_INTEGER_U8X8(1, 8, 68); + RUN_MATMUL_INTEGER_U8X8(1, 8, 400); + RUN_MATMUL_INTEGER_U8X8(1, 512, 1024); } TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMM) { - RUN_MATMUL_INTEGER_U8S8(2, 2, 40); - RUN_MATMUL_INTEGER_U8S8(2, 48, 33); - RUN_MATMUL_INTEGER_U8S8(2, 51, 40); - RUN_MATMUL_INTEGER_U8S8(4, 8, 68); + RUN_MATMUL_INTEGER_U8X8(2, 2, 40); + RUN_MATMUL_INTEGER_U8X8(2, 48, 33); + RUN_MATMUL_INTEGER_U8X8(2, 51, 40); + RUN_MATMUL_INTEGER_U8X8(4, 8, 68); } } // namespace test