Prepacking MatMulInteger (#5403)

* prepack matmulinteger
Prepacking constant matrix B for MatMulInteger to get better performance.
This commit is contained in:
Yufeng Li 2020-10-09 02:37:19 -07:00 committed by GitHub
parent 621fdb44e5
commit b99eaa99cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 78 deletions

View file

@ -1,28 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "core/common/safeint.h"
#include "core/providers/common.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/cpu/math/matmul_integer_base.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"
#include "core/mlas/inc/mlas.h"
#include <algorithm>
namespace onnxruntime {
namespace contrib {
class MatMulIntegerToFloatBase : public OpKernel {
class MatMulIntegerToFloatBase : public MatMulIntegerBase {
public:
MatMulIntegerToFloatBase(const OpKernelInfo& info) : OpKernel(info) {
MatMulIntegerToFloatBase(const OpKernelInfo& info) : MatMulIntegerBase(info) {
}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
#endif
protected:
Status ComputeCommon(OpKernelContext* ctx,
const uint8_t* a_data,
@ -32,46 +26,8 @@ class MatMulIntegerToFloatBase : public OpKernel {
uint8_t b_zero_point,
float multiplier,
const Tensor* bias_tensor) const;
bool b_is_signed_;
TensorShape b_shape_;
BufferUniquePtr packed_b_;
};
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status MatMulIntegerToFloatBase::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
is_packed = false;
// only pack Matrix B
if (input_idx == 1) {
// Only handle the common case of a 2D weight matrix. Additional matrices
// could be handled by stacking the packed buffers.
b_shape_ = tensor.Shape();
if (b_shape_.NumDimensions() != 2) {
return Status::OK();
}
const size_t K = static_cast<size_t>(b_shape_[0]);
const size_t N = static_cast<size_t>(b_shape_[1]);
const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());
b_is_signed_ = tensor.IsDataType<int8_t>();
const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_);
if (packed_b_size == 0) {
return Status::OK();
}
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
auto* packed_b_data = alloc->Alloc(packed_b_size);
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data);
is_packed = true;
}
return Status::OK();
}
#endif
Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
const uint8_t* a_data,
const TensorShape& a_shape,

View file

@ -1,17 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "matmul_integer_base.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"
#include "core/providers/common.h"
namespace onnxruntime {
class MatMulInteger final : public OpKernel {
class MatMulInteger final : public MatMulIntegerBase {
public:
MatMulInteger(const OpKernelInfo& info) : OpKernel(info) {}
MatMulInteger(const OpKernelInfo& info) : MatMulIntegerBase(info) {}
Status Compute(OpKernelContext* context) const override;
};
@ -32,10 +32,10 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const auto* a = ctx->Input<Tensor>(0);
const auto* b = ctx->Input<Tensor>(1);
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), packed_b_ ? b_shape_ : b->Shape()));
Tensor* y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
@ -59,11 +59,28 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
}
const auto* a_data = a->template Data<uint8_t>();
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
auto* y_data = y->template MutableData<int32_t>();
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
if (packed_b_) {
MlasGemm(static_cast<size_t>(helper.M()),
static_cast<size_t>(helper.N()),
static_cast<size_t>(helper.K()),
a_data + helper.LeftOffsets()[i],
static_cast<size_t>(helper.K()),
a_offset,
packed_b_.get(),
b_offset,
b_is_signed_,
y_data + helper.OutputOffsets()[i],
static_cast<size_t>(helper.N()),
thread_pool);
continue;
}
#endif
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
const bool b_is_signed = b->IsDataType<int8_t>();
QGemm(static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),

View file

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/common.h"
namespace onnxruntime {
class MatMulIntegerBase : public OpKernel {
public:
MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {}
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override {
is_packed = false;
// only pack Matrix B
if (input_idx == 1) {
// Only handle the common case of a 2D weight matrix. Additional matrices
// could be handled by stacking the packed buffers.
b_shape_ = tensor.Shape();
if (b_shape_.NumDimensions() != 2) {
return Status::OK();
}
const size_t K = static_cast<size_t>(b_shape_[0]);
const size_t N = static_cast<size_t>(b_shape_[1]);
const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());
b_is_signed_ = tensor.IsDataType<int8_t>();
const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_);
if (packed_b_size == 0) {
return Status::OK();
}
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
auto* packed_b_data = alloc->Alloc(packed_b_size);
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data);
is_packed = true;
}
return Status::OK();
}
#endif
protected:
bool b_is_signed_;
TensorShape b_shape_;
BufferUniquePtr packed_b_;
};
} // namespace onnxruntime

View file

@ -273,11 +273,12 @@ T GetMiddle(const std::vector<T>& v) {
}
// [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim]
void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_zero_zp) {
template <typename ScalarB>
void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_zero_zp, bool B_is_initializer) {
OpTester test("MatMulInteger", 10);
static std::default_random_engine e(123);
static std::uniform_int_distribution<int> n_unsigned(0, 127);
static std::uniform_int_distribution<int> n_signed(-128, 127);
static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<ScalarB>::min(), std::numeric_limits<ScalarB>::max());
Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M)
.unaryExpr([](int) { return n_unsigned(e); });
@ -286,18 +287,18 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M);
Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K)
.unaryExpr([](int) { return n_signed(e); });
std::vector<int8_t> matrix_b_data = ToVector<int8_t>(matrix_b.data(), N * K);
int8_t b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
.unaryExpr([](int) { return n_xint8(e); });
std::vector<ScalarB> matrix_b_data = ToVector<ScalarB>(matrix_b.data(), N * K);
ScalarB b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
Eigen::MatrixXi matrix_b_offset = matrix_b - b_zero_point * Eigen::MatrixXi::Ones(N, K);
Eigen::MatrixXi matrix_c = (matrix_b_offset * matrix_a_offset).eval();
test.AddInput<uint8_t>("T1", {M, K}, std::move(matrix_a_data));
test.AddInput<int8_t>("T2", {K, N}, std::move(matrix_b_data), true /*is_initializer*/);
test.AddInput<ScalarB>("T2", {K, N}, std::move(matrix_b_data), B_is_initializer);
if (non_zero_zp) {
test.AddInput<uint8_t>("a_zero_point", {}, {a_zero_point});
test.AddInput<int8_t>("b_zero_point", {}, {b_zero_point});
test.AddInput<ScalarB>("b_zero_point", {}, {b_zero_point});
}
test.AddOutput<int32_t>("T3", {M, N}, ToVector<int32_t>(matrix_c.data(), M * N));
@ -311,30 +312,36 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
}
}
#define RUN_MATMUL_INTEGER_U8S8(M, N, K) \
RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \
RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/);
#define RUN_MATMUL_INTEGER_U8X8(M, N, K) \
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/);
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) {
RUN_MATMUL_INTEGER_U8S8(1, 1, 32);
RUN_MATMUL_INTEGER_U8S8(1, 1, 260);
RUN_MATMUL_INTEGER_U8S8(1, 1, 288);
RUN_MATMUL_INTEGER_U8X8(1, 1, 32);
RUN_MATMUL_INTEGER_U8X8(1, 1, 260);
RUN_MATMUL_INTEGER_U8X8(1, 1, 288);
}
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMV) {
RUN_MATMUL_INTEGER_U8S8(1, 2, 16);
RUN_MATMUL_INTEGER_U8S8(1, 2, 64);
RUN_MATMUL_INTEGER_U8S8(1, 8, 36);
RUN_MATMUL_INTEGER_U8S8(1, 8, 68);
RUN_MATMUL_INTEGER_U8S8(1, 8, 400);
RUN_MATMUL_INTEGER_U8S8(1, 512, 1024);
RUN_MATMUL_INTEGER_U8X8(1, 2, 16);
RUN_MATMUL_INTEGER_U8X8(1, 2, 64);
RUN_MATMUL_INTEGER_U8X8(1, 8, 36);
RUN_MATMUL_INTEGER_U8X8(1, 8, 68);
RUN_MATMUL_INTEGER_U8X8(1, 8, 400);
RUN_MATMUL_INTEGER_U8X8(1, 512, 1024);
}
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMM) {
RUN_MATMUL_INTEGER_U8S8(2, 2, 40);
RUN_MATMUL_INTEGER_U8S8(2, 48, 33);
RUN_MATMUL_INTEGER_U8S8(2, 51, 40);
RUN_MATMUL_INTEGER_U8S8(4, 8, 68);
RUN_MATMUL_INTEGER_U8X8(2, 2, 40);
RUN_MATMUL_INTEGER_U8X8(2, 48, 33);
RUN_MATMUL_INTEGER_U8X8(2, 51, 40);
RUN_MATMUL_INTEGER_U8X8(4, 8, 68);
}
} // namespace test