mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-20 21:40:57 +00:00
Prepacking MatMulInteger (#5403)
* Prepack MatMulInteger: prepack the constant matrix B for MatMulInteger to get better performance.
This commit is contained in:
parent
621fdb44e5
commit
b99eaa99cd
4 changed files with 112 additions and 78 deletions
|
|
@ -1,28 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/cpu/math/matmul_integer_base.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/util/qmath.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class MatMulIntegerToFloatBase : public OpKernel {
|
||||
class MatMulIntegerToFloatBase : public MatMulIntegerBase {
|
||||
public:
|
||||
MatMulIntegerToFloatBase(const OpKernelInfo& info) : OpKernel(info) {
|
||||
MatMulIntegerToFloatBase(const OpKernelInfo& info) : MatMulIntegerBase(info) {
|
||||
}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
|
||||
#endif
|
||||
|
||||
protected:
|
||||
Status ComputeCommon(OpKernelContext* ctx,
|
||||
const uint8_t* a_data,
|
||||
|
|
@ -32,46 +26,8 @@ class MatMulIntegerToFloatBase : public OpKernel {
|
|||
uint8_t b_zero_point,
|
||||
float multiplier,
|
||||
const Tensor* bias_tensor) const;
|
||||
|
||||
bool b_is_signed_;
|
||||
TensorShape b_shape_;
|
||||
BufferUniquePtr packed_b_;
|
||||
};
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status MatMulIntegerToFloatBase::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
|
||||
is_packed = false;
|
||||
|
||||
// only pack Matrix B
|
||||
if (input_idx == 1) {
|
||||
// Only handle the common case of a 2D weight matrix. Additional matrices
|
||||
// could be handled by stacking the packed buffers.
|
||||
b_shape_ = tensor.Shape();
|
||||
if (b_shape_.NumDimensions() != 2) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t K = static_cast<size_t>(b_shape_[0]);
|
||||
const size_t N = static_cast<size_t>(b_shape_[1]);
|
||||
|
||||
const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());
|
||||
b_is_signed_ = tensor.IsDataType<int8_t>();
|
||||
|
||||
const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_);
|
||||
if (packed_b_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
|
||||
auto* packed_b_data = alloc->Alloc(packed_b_size);
|
||||
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
|
||||
MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data);
|
||||
is_packed = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
|
||||
const uint8_t* a_data,
|
||||
const TensorShape& a_shape,
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "matmul_integer_base.h"
|
||||
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "core/util/qmath.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class MatMulInteger final : public OpKernel {
|
||||
class MatMulInteger final : public MatMulIntegerBase {
|
||||
public:
|
||||
MatMulInteger(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
MatMulInteger(const OpKernelInfo& info) : MatMulIntegerBase(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
};
|
||||
|
|
@ -32,10 +32,10 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
|
|||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
const auto* a = ctx->Input<Tensor>(0);
|
||||
const auto* b = ctx->Input<Tensor>(1);
|
||||
const Tensor* b = packed_b_ ? nullptr : ctx->Input<Tensor>(1);
|
||||
|
||||
MatMulComputeHelper helper;
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), packed_b_ ? b_shape_ : b->Shape()));
|
||||
Tensor* y = ctx->Output(0, helper.OutputShape());
|
||||
|
||||
// Bail out early if the output is going to be empty
|
||||
|
|
@ -59,11 +59,28 @@ Status MatMulInteger::Compute(OpKernelContext* ctx) const {
|
|||
}
|
||||
|
||||
const auto* a_data = a->template Data<uint8_t>();
|
||||
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
|
||||
const bool b_is_signed = b->IsDataType<int8_t>();
|
||||
auto* y_data = y->template MutableData<int32_t>();
|
||||
|
||||
for (size_t i = 0; i < helper.OutputOffsets().size(); i++) {
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
if (packed_b_) {
|
||||
MlasGemm(static_cast<size_t>(helper.M()),
|
||||
static_cast<size_t>(helper.N()),
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_data + helper.LeftOffsets()[i],
|
||||
static_cast<size_t>(helper.K()),
|
||||
a_offset,
|
||||
packed_b_.get(),
|
||||
b_offset,
|
||||
b_is_signed_,
|
||||
y_data + helper.OutputOffsets()[i],
|
||||
static_cast<size_t>(helper.N()),
|
||||
thread_pool);
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
const auto* b_data = static_cast<const uint8_t*>(b->DataRaw());
|
||||
const bool b_is_signed = b->IsDataType<int8_t>();
|
||||
QGemm(static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
|
|
|
|||
54
onnxruntime/core/providers/cpu/math/matmul_integer_base.h
Normal file
54
onnxruntime/core/providers/cpu/math/matmul_integer_base.h
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class MatMulIntegerBase : public OpKernel {
|
||||
public:
|
||||
MatMulIntegerBase(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
#ifdef MLAS_SUPPORTS_PACKED_GEMM_U8X8
|
||||
Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override {
|
||||
is_packed = false;
|
||||
|
||||
// only pack Matrix B
|
||||
if (input_idx == 1) {
|
||||
// Only handle the common case of a 2D weight matrix. Additional matrices
|
||||
// could be handled by stacking the packed buffers.
|
||||
b_shape_ = tensor.Shape();
|
||||
if (b_shape_.NumDimensions() != 2) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const size_t K = static_cast<size_t>(b_shape_[0]);
|
||||
const size_t N = static_cast<size_t>(b_shape_[1]);
|
||||
|
||||
const auto* b_data = static_cast<const uint8_t*>(tensor.DataRaw());
|
||||
b_is_signed_ = tensor.IsDataType<int8_t>();
|
||||
|
||||
const size_t packed_b_size = MlasGemmPackBSize(N, K, b_is_signed_);
|
||||
if (packed_b_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
|
||||
auto* packed_b_data = alloc->Alloc(packed_b_size);
|
||||
packed_b_ = BufferUniquePtr(packed_b_data, BufferDeleter(alloc));
|
||||
MlasGemmPackB(N, K, b_data, N, b_is_signed_, packed_b_data);
|
||||
is_packed = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
protected:
|
||||
bool b_is_signed_;
|
||||
TensorShape b_shape_;
|
||||
BufferUniquePtr packed_b_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -273,11 +273,12 @@ T GetMiddle(const std::vector<T>& v) {
|
|||
}
|
||||
|
||||
// [M x N] = [M x K] x [K x N] = [batch_seq x input_dim] x [input_dim x embed_dim]
|
||||
void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_zero_zp) {
|
||||
template <typename ScalarB>
|
||||
void RunMatMulIntegerU8X8Test(const int M, const int N, const int K, bool non_zero_zp, bool B_is_initializer) {
|
||||
OpTester test("MatMulInteger", 10);
|
||||
static std::default_random_engine e(123);
|
||||
static std::uniform_int_distribution<int> n_unsigned(0, 127);
|
||||
static std::uniform_int_distribution<int> n_signed(-128, 127);
|
||||
static std::uniform_int_distribution<int> n_xint8(std::numeric_limits<ScalarB>::min(), std::numeric_limits<ScalarB>::max());
|
||||
|
||||
Eigen::MatrixXi matrix_a = Eigen::MatrixXi::Random(K, M)
|
||||
.unaryExpr([](int) { return n_unsigned(e); });
|
||||
|
|
@ -286,18 +287,18 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
|
|||
Eigen::MatrixXi matrix_a_offset = matrix_a - a_zero_point * Eigen::MatrixXi::Ones(K, M);
|
||||
|
||||
Eigen::MatrixXi matrix_b = Eigen::MatrixXi::Random(N, K)
|
||||
.unaryExpr([](int) { return n_signed(e); });
|
||||
std::vector<int8_t> matrix_b_data = ToVector<int8_t>(matrix_b.data(), N * K);
|
||||
int8_t b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
|
||||
.unaryExpr([](int) { return n_xint8(e); });
|
||||
std::vector<ScalarB> matrix_b_data = ToVector<ScalarB>(matrix_b.data(), N * K);
|
||||
ScalarB b_zero_point = non_zero_zp ? GetMiddle(matrix_b_data) : 0;
|
||||
Eigen::MatrixXi matrix_b_offset = matrix_b - b_zero_point * Eigen::MatrixXi::Ones(N, K);
|
||||
|
||||
Eigen::MatrixXi matrix_c = (matrix_b_offset * matrix_a_offset).eval();
|
||||
|
||||
test.AddInput<uint8_t>("T1", {M, K}, std::move(matrix_a_data));
|
||||
test.AddInput<int8_t>("T2", {K, N}, std::move(matrix_b_data), true /*is_initializer*/);
|
||||
test.AddInput<ScalarB>("T2", {K, N}, std::move(matrix_b_data), B_is_initializer);
|
||||
if (non_zero_zp) {
|
||||
test.AddInput<uint8_t>("a_zero_point", {}, {a_zero_point});
|
||||
test.AddInput<int8_t>("b_zero_point", {}, {b_zero_point});
|
||||
test.AddInput<ScalarB>("b_zero_point", {}, {b_zero_point});
|
||||
}
|
||||
|
||||
test.AddOutput<int32_t>("T3", {M, N}, ToVector<int32_t>(matrix_c.data(), M * N));
|
||||
|
|
@ -311,30 +312,36 @@ void RunMatMulIntegerU8S8Test(const int M, const int N, const int K, bool non_ze
|
|||
}
|
||||
}
|
||||
|
||||
#define RUN_MATMUL_INTEGER_U8S8(M, N, K) \
|
||||
RunMatMulIntegerU8S8Test(M, N, K, false /*non_zero_zp*/); \
|
||||
RunMatMulIntegerU8S8Test(M, N, K, true /*non_zero_zp*/);
|
||||
#define RUN_MATMUL_INTEGER_U8X8(M, N, K) \
|
||||
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<int8_t>(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, false /*non_zero_zp*/, false /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, false /*non_zero_zp*/, true /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, true /*non_zero_zp*/, false /*B_is_initializer*/); \
|
||||
RunMatMulIntegerU8X8Test<uint8_t>(M, N, K, true /*non_zero_zp*/, true /*B_is_initializer*/);
|
||||
|
||||
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_Scalar) {
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 1, 32);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 1, 260);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 1, 288);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 1, 32);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 1, 260);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 1, 288);
|
||||
}
|
||||
|
||||
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMV) {
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 2, 16);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 2, 64);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 8, 36);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 8, 68);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 8, 400);
|
||||
RUN_MATMUL_INTEGER_U8S8(1, 512, 1024);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 2, 16);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 2, 64);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 8, 36);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 8, 68);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 8, 400);
|
||||
RUN_MATMUL_INTEGER_U8X8(1, 512, 1024);
|
||||
}
|
||||
|
||||
TEST(MatmulIntegerOpTest, MatMulInteger_Uint8_Int8_GEMM) {
|
||||
RUN_MATMUL_INTEGER_U8S8(2, 2, 40);
|
||||
RUN_MATMUL_INTEGER_U8S8(2, 48, 33);
|
||||
RUN_MATMUL_INTEGER_U8S8(2, 51, 40);
|
||||
RUN_MATMUL_INTEGER_U8S8(4, 8, 68);
|
||||
RUN_MATMUL_INTEGER_U8X8(2, 2, 40);
|
||||
RUN_MATMUL_INTEGER_U8X8(2, 48, 33);
|
||||
RUN_MATMUL_INTEGER_U8X8(2, 51, 40);
|
||||
RUN_MATMUL_INTEGER_U8X8(4, 8, 68);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
Loading…
Reference in a new issue