mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-14 01:13:40 +00:00
Description: ### Description This is a partial change ported from fajin/qdqmatmulnbitstoolchain. That branch has issues resolving the web CI. MatMulNBits is a heavily optimized matmul operation. Currently a MatMul can be converted to MatMulNBits to speed up the model inference. However, MatMulNBits is an ORT only op. To make the graph compatible with ONNX ops and utilize MatMulNBits at the same time, we introduce Q/DQ support for MatMulNBits. To convert MatMul ops in a model to MatMulNBits: 1. use matmul_4bits_quantizer.py to convert MatMul to DQ + MatMul using QDQ mode. 2. In ORT session, DQ + MatMul is fused to MatMulNBits #### Note MatMulNBits assume B weight is uint4. When no zp is provided, zp defaults to 8, which is different from DQ. DQ defaults zp to 0 when no zp provided. And DQ supports int4. Therefore some conversions are introduced during DQ + MatMul --> MatMulNBits step. #### Perf Using QDQ format will increase the model initialization time and memory consumption. With current implement, model init time increased from ~4s to ~9s, and memory consumption increased from ~2.8GB to ~4.8GB. The memory increase is due to 1. in optimizer, after transpose the B weight, a in-memory tensor proto is created using protobuf's arena. 2. in finalize step, when saving initializer and prepacking, ORT arena is used to create buffers for initializers. The memory allocated by arenas cannot be fully deallocated. If disable ORT arena memory allocation, the memory consumptions of both QDQ format and original format are ~2.2GB. The time increase is mainly due to multiple memory copy, but can be further optimized. ### Motivation and Context Please see description for details.
138 lines
4.5 KiB
C++
138 lines
4.5 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include <pybind11/pybind11.h>
|
|
#include <pybind11/numpy.h>
|
|
#include <pybind11/functional.h>
|
|
|
|
#include "core/mlas/inc/mlas_q4.h"
|
|
#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
|
|
#include "core/util/thread_utils.h"
|
|
|
|
namespace pybind11 {
|
|
namespace detail {
|
|
// python3 -c 'import numpy as np; print(np.dtype(np.float16).num)'
|
|
constexpr int NPY_FLOAT16 = 23;
|
|
template <>
|
|
struct npy_format_descriptor<onnxruntime::MLFloat16> {
|
|
static constexpr auto name = _("float16");
|
|
static pybind11::dtype dtype() {
|
|
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16);
|
|
return reinterpret_borrow<pybind11::dtype>(ptr);
|
|
}
|
|
static std::string format() {
|
|
// following: https://docs.python.org/3/library/struct.html#format-characters
|
|
return "e";
|
|
}
|
|
};
|
|
} // namespace detail
|
|
} // namespace pybind11
|
|
|
|
namespace onnxruntime {
|
|
namespace python {
|
|
|
|
namespace py = pybind11;
|
|
using namespace onnxruntime;
|
|
|
|
template <typename T>
|
|
void QuantizeMatMul4BitsBlockwise(
|
|
py::array_t<uint8_t> dst, // shape: [ N, block_per_K, block_blob_size ]
|
|
py::array_t<T> src, // shape: [K, N]
|
|
py::array_t<T> scale, // shape: [N, block_per_K]
|
|
py::array_t<uint8_t> zero_points, // shape: [N, block_per_K] if bits > 4 else [N, (block_per_K + 1) / 2]
|
|
int32_t block_size,
|
|
int32_t N,
|
|
int32_t K,
|
|
bool is_symmetric) {
|
|
OrtThreadPoolParams to;
|
|
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
|
|
concurrency::ThreadPoolType::INTRA_OP);
|
|
|
|
py::buffer_info dst_buf = dst.request();
|
|
py::buffer_info src_buf = src.request();
|
|
py::buffer_info scale_buf = scale.request();
|
|
py::buffer_info zp_buf = zero_points.request();
|
|
|
|
MlasQuantizeBlockwise<T, 4>(
|
|
reinterpret_cast<uint8_t*>(dst_buf.ptr),
|
|
reinterpret_cast<T*>(scale_buf.ptr),
|
|
is_symmetric ? nullptr : reinterpret_cast<uint8_t*>(zp_buf.ptr),
|
|
reinterpret_cast<const T*>(src_buf.ptr),
|
|
block_size,
|
|
true,
|
|
K,
|
|
N,
|
|
N,
|
|
tp.get());
|
|
}
|
|
|
|
template <typename T>
|
|
bool QuantizeQDQMatMul4BitsBlockwise(
|
|
py::array_t<uint8_t> dst, // shape: [K, N / 2]
|
|
py::array_t<T> src, // shape: [K, N]
|
|
py::array_t<T> scale, // shape: [block_per_K, N]
|
|
py::array_t<uint8_t> zero_points, // shape: [block_per_K, N / 2]
|
|
int32_t quant_block_size,
|
|
int32_t N,
|
|
int32_t K,
|
|
bool is_symmetric) {
|
|
OrtThreadPoolParams to;
|
|
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
|
|
concurrency::ThreadPoolType::INTRA_OP);
|
|
|
|
py::buffer_info dst_buf = dst.request();
|
|
py::buffer_info src_buf = src.request();
|
|
py::buffer_info scale_buf = scale.request();
|
|
py::buffer_info zp_buf = zero_points.request();
|
|
|
|
return MlasQDQQuantizeBlockwise<T, 4>(
|
|
reinterpret_cast<const T*>(src_buf.ptr),
|
|
reinterpret_cast<T*>(scale_buf.ptr),
|
|
is_symmetric ? nullptr : reinterpret_cast<uint8_t*>(zp_buf.ptr),
|
|
reinterpret_cast<uint8_t*>(dst_buf.ptr),
|
|
true,
|
|
K,
|
|
N,
|
|
quant_block_size,
|
|
tp.get());
|
|
}
|
|
|
|
template <typename T>
|
|
void QuantizeMatMulBnb4Blockwise(
|
|
py::array_t<uint8_t> dst,
|
|
py::array_t<T> src,
|
|
py::array_t<T> absmax,
|
|
int32_t block_size,
|
|
int32_t quant_type,
|
|
int32_t N,
|
|
int32_t K) {
|
|
OrtThreadPoolParams to;
|
|
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
|
|
concurrency::ThreadPoolType::INTRA_OP);
|
|
|
|
py::buffer_info dst_buf = dst.request();
|
|
py::buffer_info src_buf = src.request();
|
|
py::buffer_info absmax_buf = absmax.request();
|
|
|
|
contrib::QuantizeBlockwiseBnb4<T>(
|
|
static_cast<uint8_t*>(dst_buf.ptr),
|
|
static_cast<const T*>(src_buf.ptr),
|
|
static_cast<T*>(absmax_buf.ptr),
|
|
block_size,
|
|
quant_type,
|
|
N,
|
|
K,
|
|
tp.get());
|
|
}
|
|
|
|
void CreateQuantPybindModule(py::module& m) {
|
|
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<float>);
|
|
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<MLFloat16>);
|
|
m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<float>);
|
|
m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<MLFloat16>);
|
|
m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise<float>);
|
|
m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise<MLFloat16>);
|
|
}
|
|
|
|
} // namespace python
|
|
} // namespace onnxruntime
|