Add MatMul FP4 and NF4 Support (#18066)

### Description
Add a contrib op MatMulBnb4 (FP4 and NF4) and related toolchain to
support quantization on weight.

This PR adds:
- schema for contrib op MatMulBnb4 which can support FP4 (4-bit floating
point) and NF4 (4-bit NormalFloat) quantization on weight.
- a naive implementation for MatMulBnb4 on CPU and GPU, i.e.,
implemented like MatMul(A, Dequantize(B)).
- a special implementation for GemV for MatMulBnb4 and related benchmark
tool.
- tool to quantize model to FP4 or NF4.
This commit is contained in:
Jambay Kinley 2023-10-25 15:34:58 -07:00 committed by GitHub
parent d88d52eead
commit d30d4d372a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 2236 additions and 0 deletions

View file

@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/dequantize_blockwise_bnb4.cuh"
"quantization/dequantize_blockwise_bnb4.cu"
"quantization/matmul_bnb4.cc"
"quantization/matmul_bnb4.cuh"
"quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"

View file

@ -47,6 +47,7 @@ Do not modify directly.*
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
* <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>
MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'.
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>K</tt> : int (required)</dt>
<dd>size of each input feature</dd>
<dt><tt>N</tt> : int (required)</dt>
<dd>size of each output feature</dd>
<dt><tt>block_size</tt> : int (required)</dt>
<dd>number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>quant_type</tt> : int (required)</dt>
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
</dl>
#### Inputs
<dl>
<dt><tt>A</tt> : T1</dt>
<dd>The input tensor, not quantized</dd>
<dt><tt>B</tt> : T2</dt>
<dd>1-dimensional quantized data for weight</dd>
<dt><tt>absmax</tt> : T1</dt>
<dd>quantization constants</dd>
</dl>
#### Outputs
<dl>
<dt><tt>Y</tt> : T1</dt>
<dd>tensor. The output tensor has the same rank as the input. </dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float/half_float tensors.</dd>
<dt><tt>T2</tt> : tensor(uint8)</dt>
<dd>Constrain quantized weight types to uint8.</dd>
</dl>
### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>
Matrix product with right hand matrix being pre-packed and quantized int4 data blob.

View file

@ -457,6 +457,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
@ -852,6 +853,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|

View file

@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
#endif

View file

@ -0,0 +1,202 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
#include <algorithm>
#include <cmath>
namespace onnxruntime {
namespace contrib {
#if defined(_MSC_VER)
#define FORCEINLINE __forceinline
#else
#define FORCEINLINE __attribute__((always_inline)) inline
#endif
typedef enum Bnb_DataType_t {
FP4 = 0,
NF4 = 1,
} Bnb_DataType_t;
FORCEINLINE uint8_t QuantizeOneFP4(float x) {
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12
// we do a binary search
// the pivots are divided by 12 (the FP4 absmax)
// since we assum input data is in [-1.0, 1.0]
// !be careful here, its easy to make a mistake
// that is difficult to noice if you add an extra
// zero somewhere!
uint8_t sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if (x > 0.29166667f) {
if (x > 0.583333f) {
if (x > 0.8333333f) {
return 0b0011 + sign;
} else {
return 0b0010 + sign;
}
} else if (x > 0.4166667f) {
return 0b101 + sign;
} else {
return 0b100 + sign;
}
} else if (x > 0.0859375f) {
if (x > 0.20833333f) {
return 0b0111 + sign;
} else {
return 0b0110 + sign;
}
} else if (x > 0.00260417f) {
return 0b0001 + sign;
} else {
return 0b0000 + sign;
}
}
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
if (x > 0.03979014977812767f) {
if (x > 0.3893125355243683f) { // 1
if (x > 0.6427869200706482f) { // 11
if (x > 0.8614784181118011f) { // 111
return 0b1111;
} else {
return 0b1110;
}
} else if (x > 0.5016634166240692f) { // 110
return 0b1101;
} else {
return 0b1100;
}
} else if (x > 0.2035212516784668f) { // 10
if (x > 0.2920137718319893f) { // 101
return 0b1011;
} else {
return 0b1010;
}
} else if (x > 0.1202552504837513f) { // 100
return 0b1001;
} else {
return 0b1000;
}
} else if (x > -0.33967943489551544f) { // 0
if (x > -0.13791173323988914f) { // 01
if (x > -0.045525018125772476f) { // 011
return 0b0111;
} else {
return 0b0110;
}
} else if (x > -0.23460740596055984f) { // 010
return 0b0101;
} else {
return 0b0100;
}
} else if (x > -0.6106329262256622f) { // 00
if (x > -0.4599952697753906f) { // 001
return 0b0011;
} else {
return 0b0010;
}
} else if (x > -0.8480964004993439f) { // 000
return 0b0001;
} else {
return 0b0000;
}
}
template <int32_t DATA_TYPE>
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
if constexpr (DATA_TYPE == FP4)
return QuantizeOneFP4(x);
else
return QuantizeOneNF4(x);
}
template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
float local_absmax = 0.0f;
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size;
int32_t dst_offset = block_idx * block_size / 2;
for (int32_t idx = 0; idx < block_len; idx++) {
const float v = static_cast<float>(src[src_offset + idx]);
local_absmax = fmaxf(local_absmax, fabsf(v));
}
absmax_block = static_cast<T>(local_absmax);
const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
for (int32_t idx = 0; idx < block_len; idx += 2) {
const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
}
}
static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
static float nf4_qaunt_map[16] = {-1.0f,
-0.6961928009986877f,
-0.5250730514526367f,
-0.39491748809814453f,
-0.28444138169288635f,
-0.18477343022823334f,
-0.09105003625154495f,
0.0f,
0.07958029955625534f,
0.16093020141124725f,
0.24611230194568634f,
0.33791524171829224f,
0.44070982933044434f,
0.5626170039176941f,
0.7229568362236023f,
1.0f};
template <typename T, int32_t DATA_TYPE>
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
if constexpr (DATA_TYPE == FP4)
return static_cast<T>(fp4_qaunt_map[x]);
else
return static_cast<T>(nf4_qaunt_map[x]);
}
template <typename T, int32_t block_size, int32_t DATA_TYPE>
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
int32_t src_offset = block_idx * block_size / 2;
int32_t dst_offset = block_idx * block_size;
for (int32_t idx = 0; idx < block_len; idx += 2) {
const uint8_t val = src[src_offset + idx / 2];
dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
}
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,143 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "blockwise_quant_block_bnb4.h"
#include <vector>
#include "core/common/safeint.h"
#include "core/framework/float16.h"
#include "core/platform/threadpool.h"
#include <iostream>
namespace onnxruntime {
namespace contrib {
template <typename T, int32_t block_size, int32_t DATA_TYPE>
void QuantizeBlockwiseBnb4(
uint8_t* dst, // shape: [(N * K + 1) / 2]
const T* src, // shape: [N, K]
T* absmax, // shape: [(N * K + block_size - 1) / block_size]
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
int32_t numel = N * K;
int32_t total_block_count = (numel + block_size - 1) / block_size;
concurrency::ThreadPool::TryBatchParallelFor(
thread_pool,
total_block_count,
[&](ptrdiff_t block_idx) {
QuantizeBlockBnb4<T, block_size, DATA_TYPE>(
src,
dst,
absmax[block_idx],
static_cast<int32_t>(block_idx),
numel);
},
0);
}
#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \
if (quant_type == FP4) \
QuantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
else \
QuantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
template <typename T>
void QuantizeBlockwiseBnb4(
uint8_t* dst, // shape: [(N * K + 1) / 2]
const T* src, // shape: [N, K]
T* absmax, // shape: [(N * K + block_size - 1) / block_size]
int32_t block_size,
int32_t quant_type,
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
ORT_ENFORCE(
quant_type == FP4 || quant_type == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
if (block_size == 16) {
QuantizeBlockwiseBn4DataTyped(16, quant_type);
} else if (block_size == 32) {
QuantizeBlockwiseBn4DataTyped(32, quant_type);
} else if (block_size == 64) {
QuantizeBlockwiseBn4DataTyped(64, quant_type);
} else if (block_size == 128) {
QuantizeBlockwiseBn4DataTyped(128, quant_type);
} else if (block_size == 256) {
QuantizeBlockwiseBn4DataTyped(256, quant_type);
} else {
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
}
}
#undef QuantizeBlockwiseBn4DataTyped
template <typename T, int32_t block_size, int32_t DATA_TYPE>
void DequantizeBlockwiseBnb4(
T* dst, // shape: [N, K]
const uint8_t* src, // shape: [(N * K + 1) / 2)]
const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
int32_t numel = N * K;
int32_t total_block_count = (numel + block_size - 1) / block_size;
concurrency::ThreadPool::TryBatchParallelFor(
thread_pool,
total_block_count,
[&](ptrdiff_t block_idx) {
DequantizeBlockBnb4<T, block_size, DATA_TYPE>(
src,
dst,
absmax[block_idx],
static_cast<int32_t>(block_idx),
numel);
},
0);
}
#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \
if (quant_type == FP4) \
DequantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
else \
DequantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
template <typename T>
void DequantizeBlockwiseBnb4(
T* dst, // shape: [N, K]
const uint8_t* src, // shape: [(N * K + 1) / 2)]
const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
int32_t block_size,
int32_t quant_type,
int32_t N,
int32_t K,
onnxruntime::concurrency::ThreadPool* thread_pool) {
ORT_ENFORCE(
quant_type == FP4 || quant_type == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
if (block_size == 16) {
DequantizeBlockwiseBn4DataTyped(16, quant_type);
} else if (block_size == 32) {
DequantizeBlockwiseBn4DataTyped(32, quant_type);
} else if (block_size == 64) {
DequantizeBlockwiseBn4DataTyped(64, quant_type);
} else if (block_size == 128) {
DequantizeBlockwiseBn4DataTyped(128, quant_type);
} else if (block_size == 256) {
DequantizeBlockwiseBn4DataTyped(256, quant_type);
} else {
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
}
}
#undef DequantizeBlockwiseBn4DataTyped
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "dequantize_blockwise_bnb4.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
namespace contrib {
class MatMulBnb4 final : public OpKernel {
public:
MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
}
Status Compute(OpKernelContext* context) const override;
private:
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
};
Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b_quant = ctx->Input<Tensor>(1);
const Tensor* absmax = ctx->Input<Tensor>(2);
const float* a_data = a->Data<float>();
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
const float* absmax_data = absmax->Data<float>();
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
DequantizeBlockwiseBnb4<float>(
tmp_b_data_ptr.get(),
b_quant_data,
absmax_data,
static_cast<int32_t>(block_size_),
static_cast<int32_t>(quant_type_),
static_cast<int32_t>(N_),
static_cast<int32_t>(K_),
thread_pool);
constexpr bool transa = false;
constexpr bool transb = true;
TensorShape b_shape({N_, K_});
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
Tensor* y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0) return Status::OK();
auto* y_data = y->MutableData<float>();
const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(transa);
const size_t ldb = helper.Ldb(transb);
// TODO: implement with native kernel
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
for (size_t i = 0; i < max_len; i++) {
data[i].BIsPacked = false;
data[i].A = a_data + helper.LeftOffsets()[i];
data[i].lda = lda;
data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i];
data[i].ldb = ldb;
data[i].C = y_data + helper.OutputOffsets()[i];
data[i].ldc = N;
data[i].alpha = 1.f;
data[i].beta = 0.0f;
}
MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool);
return Status::OK();
}
ONNX_OPERATOR_KERNEL_EX(
MatMulBnb4,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
MatMulBnb4);
} // namespace contrib
} // namespace onnxruntime

View file

@ -118,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout)>,

View file

@ -0,0 +1,129 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
#include "dequantize_blockwise_bnb4.cuh"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T>
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
ORT_ENFORCE(
quant_type == FP4 || quant_type == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
T host_quant_map[16];
switch (quant_type) {
case FP4:
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
break;
case NF4:
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
break;
}
CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream));
return Status::OK();
}
template Status SetBnbQuantMap<float>(int quant_type, float* quant_map_buffer, cudaStream_t stream);
template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cudaStream_t stream);
template <class T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
__global__ void kDequantizeBlockwise(
const T* quant_map,
T* output,
const uint8_t* quant_data,
const T* absmax,
const int block_size,
const int n) {
const int n_load = (gridDim.x * TILE_SIZE);
int valid_items_load = 0;
int valid_items_store = 0;
const int base_idx = (blockIdx.x * TILE_SIZE);
T vals[NUM_PER_TH * 2];
uint8_t qvals[NUM_PER_TH];
T local_abs_max = T(0.0f);
typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<T, THREADS, NUM_PER_TH * 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i;
valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2;
local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]);
__syncthreads();
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
#else
// half multiplication not supported
vals[j * 2] = static_cast<T>(static_cast<float>(quant_map[qvals[j] >> 4]) * static_cast<float>(local_abs_max));
vals[j * 2 + 1] =
static_cast<T>(static_cast<float>(quant_map[qvals[j] & 0x0F]) * static_cast<float>(local_abs_max));
#endif
}
__syncthreads();
StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store);
}
}
template <class T>
Status DequantizeBnb4(
const T* quant_map,
T* output,
const uint8_t* quant_data,
const T* absmax,
int block_size,
int numel,
cudaStream_t stream) {
int tile_size = 1024;
kDequantizeBlockwise<T, 512, 64, 8><<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>(
quant_map,
output,
quant_data,
absmax,
block_size / 2,
numel);
return Status::OK();
}
template Status DequantizeBnb4<float>(
const float* quant_map,
float* output,
const uint8_t* quant_data,
const float* absmax,
int block_size,
int numel,
cudaStream_t stream);
template Status DequantizeBnb4<half>(
const half* quant_map,
half* output,
const uint8_t* quant_data,
const half *absmax,
int block_size,
int numel,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T>
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream);
template <class T>
Status DequantizeBnb4(
const T* quant_map,
T* output,
const uint8_t* quant_data,
const T* absmax,
int block_size,
int numel,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
#include "matmul_bnb4.cuh"
#include "dequantize_blockwise_bnb4.cuh"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
template <typename T>
class MatMulBnb4 final : public CudaKernel {
public:
MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t K_;
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
};
template <typename T>
Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b_quant = ctx->Input<Tensor>(1);
const Tensor* absmax = ctx->Input<Tensor>(2);
const auto* a_data = a->Data<T>();
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
const auto* absmax_data = absmax->Data<T>();
typedef typename ToCudaType<T>::MappedType CudaT;
// TODO: find a better way to create the quant_map without using a buffer
// don't want to use malloc directly so asking from the caller
// can create a __device__ static array for float but doesn't work for half
IAllocatorUniquePtr<T> quant_map_buffer = GetScratchBuffer<T>(16, ctx->GetComputeStream());
auto* quant_map_buffer_data = quant_map_buffer.get();
ORT_RETURN_IF_ERROR(SetBnbQuantMap<CudaT>(
SafeInt<int>(quant_type_),
reinterpret_cast<CudaT*>(quant_map_buffer_data),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
constexpr bool transa = false;
constexpr bool transb = true;
MatMulComputeHelper helper;
TensorShape b_shape({N_, K_});
ORT_RETURN_IF_ERROR(
helper.Compute(a->Shape(), b_shape, transa, transb));
Tensor* Y = ctx->Output(0, helper.OutputShape());
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();
bool is_4bit_done = TryMatMulBnb4(
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
reinterpret_cast<const CudaT*>(a_data),
b_quant_data,
reinterpret_cast<const CudaT*>(absmax_data),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.N()),
SafeInt<int>(helper.K()),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
if (!is_4bit_done) {
IAllocatorUniquePtr<T> b_dequant_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
auto* b_dequant_data = b_dequant_ptr.get();
ORT_RETURN_IF_ERROR(DequantizeBnb4<CudaT>(
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
reinterpret_cast<CudaT*>(b_dequant_data),
b_quant_data,
reinterpret_cast<const CudaT*>(absmax_data),
SafeInt<int>(block_size_),
SafeInt<int>(N_ * K_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
const CudaT zero = ToCudaType<T>::FromFloat(0.f);
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
GetCublasHandle(ctx),
CUBLAS_OP_T,
CUBLAS_OP_N,
SafeInt<int>(helper.N()),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.K()),
&alpha,
reinterpret_cast<const CudaT*>(b_dequant_data),
SafeInt<int>(K_),
reinterpret_cast<const CudaT*>(a_data),
helper.Lda(transa),
&zero,
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
helper.Ldc(),
GetDeviceProp()));
}
return Status::OK();
}
ONNX_OPERATOR_TYPED_KERNEL_EX(
MatMulBnb4,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
MatMulBnb4<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
MatMulBnb4,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
MatMulBnb4<MLFloat16>);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <type_traits>
#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "matmul_bnb4.cuh"
namespace onnxruntime {
namespace contrib {
namespace cuda {
#define num_values_4bit 32
template <class T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M,
int N,
int K,
const T* __restrict__ A,
const uint8_t* B,
const T* absmax,
const T* datatype,
T* out,
int lda,
int ldb,
int ldc,
int block_size) {
// per threadblock:
// load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1x32 * 32x4 -> 1x4 outputs per thread block
typedef cub::WarpReduce<float> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];
const int warp_idx = threadIdx.x / 32;
const int warp_lane = threadIdx.x % 32;
const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
const int num_values_8bit = num_values_4bit / 2;
float local_C = 0.0f;
uint8_t local_B_4bit[num_values_8bit];
T local_B[num_values_4bit / 4];
T local_A[num_values_4bit / 4];
__shared__ T quant_map[16];
T local_absmax = T(0.0f);
for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]);
__syncthreads();
// A: [1, K]
// B: [N, K]
for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
int inner_idx_halved = inner_idx / 2;
int offset_B = ldb * row_B;
int absidx = ((2 * offset_B) + inner_idx) / block_size;
local_absmax = __ldg(&(absmax[absidx]));
if (row_B < N) {
if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
// this is the most important for performance considerations
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
reinterpret_cast<const int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++)
if ((inner_idx_halved) + j < (K / 2))
local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
else
local_B_4bit[j] = 0b01110111;
}
} else {
#pragma unroll
for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111;
}
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int k = 0; k < num_values_8bit / 4; k++) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
#else
// half multiplication not supported
local_B[k * 2] =
static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) *
static_cast<float>(local_absmax));
local_B[k * 2 + 1] =
static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) *
static_cast<float>(local_absmax));
#endif
}
if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
// this is also relatively important for performance
if (BITS == 16) {
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
} else {
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
}
} else {
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
if (inner_idx + (i * num_values_4bit / 4) + k < K)
local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
else
local_A[k] = T(0.0f);
}
}
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for (int k = 0; k < num_values_4bit / 4; k++) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
local_C += static_cast<float>(local_A[k] * local_B[k]);
#else
// half multiplication not supported
local_C += static_cast<float>(local_A[k]) * static_cast<float>(local_B[k]);
#endif
}
}
}
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
if (row_B < N && warp_lane == 0) out[row_B] = T(local_C);
}
template <class T>
bool TryMatMulBnb4(
const T* quant_map,
T* output,
const T* a_data,
const uint8_t* b_data_quant,
const T* absmax,
int m,
int n,
int k,
int block_size,
cudaStream_t stream) {
if (k % block_size != 0 || m > 1) {
return false;
}
// supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32]
if (block_size % 32 != 0 || block_size > 4096) {
return false;
}
int lda = k;
int ldb = (k + 1) / 2;
int ldc = n;
int num_blocks = (n + 3) / 4;
constexpr int bits = std::is_same_v<T, half> ? 16 : 32;
kgemm_4bit_inference_naive<T, 128, bits><<<num_blocks, 128, 0, stream>>>(
m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size);
return true;
}
template bool TryMatMulBnb4<float>(
const float* quant_map,
float* output,
const float* a_data,
const uint8_t* b_data_quant,
const float* absmax,
int m,
int n,
int k,
int block_size,
cudaStream_t stream);
template bool TryMatMulBnb4<half>(
const half* quant_map,
half* output,
const half* a_data,
const uint8_t* b_data_quant,
const half* absmax,
int m,
int n,
int k,
int block_size,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T>
bool TryMatMulBnb4(
const T* quant_map,
T* output,
const T* a_data,
const uint8_t* b_data_quant,
const T* absmax,
int m,
int n,
int k,
int block_size,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
});
static const char* MatMulBnb4_ver1_doc = R"DOC(
MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'.
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
)DOC";
ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(MatMulBnb4_ver1_doc)
.Attr("K", "size of each input feature", AttributeProto::INT)
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional quantized data for weight", "T2")
.Input(2, "absmax", "quantization constants", "T1")
.Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1")
.TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
.TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
int64_t in_features = getAttribute(ctx, "K", -1);
int64_t out_features = getAttribute(ctx, "N", -1);
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
});
#ifdef ENABLE_ATEN
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
.SetDomain(kPytorchAtenDomain)

View file

@ -6,6 +6,7 @@
#include <pybind11/functional.h>
#include "contrib_ops/cpu/quantization/dequantize_blockwise.h"
#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
#include "core/util/thread_utils.h"
namespace pybind11 {
@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise(
tp.get());
}
template <typename T>
void QuantizeMatMulBnb4Blockwise(
py::array_t<uint8_t> dst,
py::array_t<T> src,
py::array_t<T> absmax,
int32_t block_size,
int32_t quant_type,
int32_t N,
int32_t K) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
concurrency::ThreadPoolType::INTRA_OP);
py::buffer_info dst_buf = dst.request();
py::buffer_info src_buf = src.request();
py::buffer_info absmax_buf = absmax.request();
contrib::QuantizeBlockwiseBnb4<T>(
static_cast<uint8_t*>(dst_buf.ptr),
static_cast<const T*>(src_buf.ptr),
static_cast<T*>(absmax_buf.ptr),
block_size,
quant_type,
N,
K,
tp.get());
}
void CreateQuantPybindModule(py::module& m) {
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<float>);
m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<MLFloat16>);
m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<float>);
m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<MLFloat16>);
}
} // namespace python

View file

@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file serve as a simple example for adding a tunable op to onnxruntime.
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <pybind11/pybind11.h>
#include <string>
#include "core/providers/cuda/tunable/cuda_tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
namespace py = pybind11;
namespace onnxruntime {
// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct DequantizeBnb4Params : cuda::tunable::OpParams {
std::string Signature() const override { return std::to_string(n_); }
int quant_type_;
T* output_;
const uint8_t* quant_;
const T* absmax_;
T* quant_map_buffer_;
int n_;
int k_;
};
template <typename T>
class DequantizeBnb4 : public IKernelExplorer {
public:
DequantizeBnb4(
int quant_type,
DeviceArray& output,
DeviceArray& quant,
DeviceArray& absmax,
DeviceArray& quant_map_buffer,
int n, int k) {
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
params_.quant_type_ = quant_type;
params_.output_ = static_cast<T*>(output.ptr());
params_.quant_ = static_cast<uint8_t*>(quant.ptr());
params_.absmax_ = static_cast<T*>(absmax.ptr());
params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
params_.n_ = n;
params_.k_ = k;
}
void Run() override {
ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
params_.quant_type_,
params_.quant_map_buffer_,
params_.StreamHandle()));
ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4(
params_.quant_map_buffer_,
params_.output_,
params_.quant_,
params_.absmax_,
64,
params_.n_ * params_.k_,
params_.StreamHandle()));
}
private:
// A VectorAddOp<T> is a callable that can process const VectorAddParams<T>*
using ParamsT = DequantizeBnb4Params<T>;
ParamsT params_{};
};
#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<int, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run);
KE_REGISTER(m) {
REGISTER_OP(DequantizeBnb4, half);
REGISTER_OP(DequantizeBnb4, float);
}
} // namespace onnxruntime

View file

@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file serve as a simple example for adding a tunable op to onnxruntime.
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <pybind11/pybind11.h>
#include <string>
#include "core/providers/cuda/tunable/cuda_tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh"
namespace py = pybind11;
namespace onnxruntime {
// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct MatrixFloatBnb4Params : cuda::tunable::OpParams {
std::string Signature() const override { return std::to_string(n_); }
int quant_type_;
T* output_;
const T* a_;
const uint8_t* b_;
const T* absmax_;
T* quant_map_buffer_;
int m_;
int n_;
int k_;
};
template <typename T>
class MatrixFloatBnb4 : public IKernelExplorer {
public:
MatrixFloatBnb4(DeviceArray& output,
DeviceArray& a,
DeviceArray& b,
DeviceArray& absmax,
DeviceArray& quant_map_buffer,
int quant_type, int m, int n, int k) {
params_.tuning_ctx = TuningContext();
params_.stream = Stream();
params_.output_ = static_cast<T*>(output.ptr());
params_.a_ = static_cast<T*>(a.ptr());
params_.b_ = static_cast<uint8_t*>(b.ptr());
params_.absmax_ = static_cast<T*>(absmax.ptr());
params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
params_.quant_type_ = quant_type;
params_.m_ = m;
params_.n_ = n;
params_.k_ = k;
}
void Run() override {
ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
params_.quant_type_,
params_.quant_map_buffer_,
params_.StreamHandle()));
contrib::cuda::TryMatMulBnb4(
params_.quant_map_buffer_,
params_.output_,
params_.a_,
params_.b_,
params_.absmax_,
params_.m_,
params_.n_,
params_.k_,
64,
params_.StreamHandle());
}
private:
// A VectorAddOp<T> is a callable that can process const VectorAddParams<T>*
using ParamsT = MatrixFloatBnb4Params<T>;
ParamsT params_{};
};
#define REGISTER_OP(name, type) \
py::class_<name<type>>(m, #name "_" #type) \
.def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int, int>()) \
.def("SetRepeats", &name<type>::SetRepeats) \
.def("Profile", &name<type>::Profile) \
.def("Run", &name<type>::Run);
KE_REGISTER(m) {
REGISTER_OP(MatrixFloatBnb4, half);
REGISTER_OP(MatrixFloatBnb4, float);
}
} // namespace onnxruntime

View file

@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import sys
from dataclasses import dataclass
import kernel_explorer as ke
import numpy as np
from utils import dtype_to_bytes
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))),
"float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))),
}
return type_map[dtype]
quant_enums = {"FP4": 0, "NF4": 1}
dtypes = ["float16", "float32"]
quant_types = ["FP4", "NF4"]
@dataclass
class DequantizeBnb4Metric(ke.BandwidthMetric):
quant_type: str
n: int
k: int
def report(self):
return (
f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}"
)
def profile_dequantize_int4_func(qt, n, k, dtype, func):
np.random.seed(0)
block_size = 64
numel = n * k
output = np.random.rand(n, k).astype(dtype)
quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
quant_map_buffer = np.zeros(16).astype(dtype)
output_d = ke.DeviceArray(output)
quant_d = ke.DeviceArray(quant)
absmax_d = ke.DeviceArray(absmax)
quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
f = getattr(ke, func)
my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k)
duration_ms = my_op.Profile()
total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype)
ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k))
def profile_with_args(qt, n, k, dtype, sort):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
profile_dequantize_int4_func(qt, n, k, dtype, func)
def profile():
for qt in quant_types:
for dt in dtypes:
for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
profile_with_args(qt, n, k, dt, True)
print()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("quant_type", choices=quant_types)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort)

View file

@ -0,0 +1,136 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import sys
from dataclasses import dataclass
import kernel_explorer as ke
import numpy as np
from utils import dtype_to_bytes
def dtype_to_funcs(dtype):
type_map = {
"float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))),
"float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))),
}
return type_map[dtype]
def dtype_to_funcs_cublas(dtype):
type_map = {
"float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))),
"float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))),
}
return type_map[dtype]
quant_enums = {"FP4": 0, "NF4": 1}
dtypes = ["float16", "float32"]
quant_types = ["FP4", "NF4"]
@dataclass
class MatrixMulMetric(ke.BandwidthMetric):
m: int
n: int
k: int
def report(self):
return (
f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
)
@dataclass
class MatrixFpBnb4Metric(MatrixMulMetric):
quant_type: str
def report(self):
return (
f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
)
def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func):
np.random.seed(0)
block_size = 64
numel = n * k
output = np.random.rand(m, n).astype(dtype)
a = np.random.rand(m, k).astype(dtype)
b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
quant_map_buffer = np.zeros(16).astype(dtype)
output_d = ke.DeviceArray(output)
a_d = ke.DeviceArray(a)
b_d = ke.DeviceArray(b)
absmax_d = ke.DeviceArray(absmax)
quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
f = getattr(ke, func)
my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k)
duration_ms = my_op.Profile()
total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt))
def profile_gemm_func(m, n, k, dtype, func):
np.random.seed(0)
output = np.random.rand(m, n).astype(dtype)
a = np.random.rand(m, k).astype(dtype)
b = np.random.rand(k, n).astype(dtype)
output_d = ke.DeviceArray(output)
a_d = ke.DeviceArray(a)
b_d = ke.DeviceArray(b)
f = getattr(ke, func)
my_op = f(output_d, a_d, b_d, m, n, k)
duration_ms = my_op.Profile()
total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))
ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k))
def profile_with_args(qt, m, n, k, dtype, sort):
with ke.benchmark(sort):
for func in dtype_to_funcs(dtype):
profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func)
for func in dtype_to_funcs_cublas(dtype):
profile_gemm_func(m, n, k, dtype, func)
def profile():
dims_m = [1]
for qt in quant_types:
for dt in dtypes:
for m in dims_m:
for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
profile_with_args(qt, m, n, k, dt, False)
print()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
group = parser.add_argument_group("profile with args")
group.add_argument("m", type=int)
group.add_argument("n", type=int)
group.add_argument("k", type=int)
group.add_argument("quant_type", choices=quant_types)
group.add_argument("dtype", choices=dtypes)
group.add_argument("--sort", action="store_true")
if len(sys.argv) == 1:
profile()
else:
args = parser.parse_args()
profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort)

View file

@ -0,0 +1,240 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import argparse
import logging
import os
from typing import List, Tuple
import numpy as np
import numpy.typing as npt
import onnx
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
from .onnx_model import ONNXModel
from .quant_utils import attribute_to_kwarg
logger = logging.getLogger(__name__)
class MatMulBnb4Quantizer:
"""Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type"""
##################
# quantization types, must be consistent with native code type
# Bnb_DataType_t defined in blockwise_quant_block_bnb4.h
# 4b floating point with bias of 3
FP4 = 0
# 4b NormalFloat
NF4 = 1
def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None):
nodes_to_exclude = nodes_to_exclude or []
assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4]
self.model = ONNXModel(model)
self.quant_type = quant_type
self.block_size = block_size
self.nodes_to_exclude = set(nodes_to_exclude)
@staticmethod
def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
for gid in range(len(graph_path) - 1, -1, -1):
graph = graph_path[gid]
for tensor in graph.initializer:
if tensor.name == name:
return tensor, graph
return None, None
def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray:
"""4b quantize fp32/fp16 weight"""
if len(fpweight.shape) != 2:
raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
# need to copy since the transposed weight still has the original memory layout
# Linear4bit quantizes its weight data which is the transposed weight
fpweight_t = fpweight.transpose().copy()
rows, cols = fpweight.shape
numel = rows * cols
block_size = self.block_size
num_blocks = (numel + block_size - 1) // block_size
quantized_numel = (numel + 1) // 2
packed = np.zeros(quantized_numel, dtype="uint8")
absmax = np.zeros(num_blocks, dtype=fpweight.dtype)
# block wise quantization, fpweight_t is flattened and divided into blocks
quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows)
return (packed, absmax)
def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""
if node.op_type != "MatMul":
return node # only care about MatMul for now
logger.debug(f"start to quantize {node.name} ...")
if node.name in self.nodes_to_exclude:
logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
return node
inputB = node.input[1] # noqa: N806
B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806
if B is None:
logger.debug("MatMul doesn't have const weight. Skip to quantize")
return node # only care about constant weight
B_array = onnx.numpy_helper.to_array(B) # noqa: N806
if len(B_array.shape) != 2:
logger.debug("MatMul weight is not 2D. Skip to quantize")
return node # can only process 2-D matrix
packed, absmax = self.bnb4_block_quant(B_array)
B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
B_quant.name = B.name + "_Bnb4"
for input in Bs_graph.input:
if input.name == inputB:
Bs_graph.input.remove(input)
break
absmax_tensor = onnx.numpy_helper.from_array(absmax)
absmax_tensor.name = B.name + "_absmax"
Bs_graph.initializer.extend([B_quant, absmax_tensor])
kwargs = {}
rows, cols = B_array.shape
kwargs["K"] = rows
kwargs["N"] = cols
kwargs["block_size"] = self.block_size
kwargs["quant_type"] = self.quant_type
matmul_bnb4_node = onnx.helper.make_node(
"MatMulBnb4",
inputs=[node.input[0], B_quant.name, absmax_tensor.name],
outputs=[node.output[0]],
name=node.name + "_Bnb4" if node.name else "",
domain="com.microsoft",
**kwargs,
)
logger.debug(f"complete quantization of {node.name} ...")
return matmul_bnb4_node
def _process_subgraph(self, graph_stack: List[GraphProto]):
new_nodes = []
graph = graph_stack[-1]
for node in graph.node:
graph_attrs = [
attr
for attr in node.attribute
if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
]
if len(graph_attrs):
kwargs = {}
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
# recursive call to take care of sub-graph
graph_stack.append(attr.g)
kv = {attr.name: self._process_subgraph(graph_stack)}
elif attr.type == onnx.AttributeProto.GRAPHS:
value = []
for subgraph in attr.graphs:
# recursive call to take care of sub-graph
graph_stack.append(subgraph)
value.extend([self._process_subgraph(graph_stack)])
kv = {attr.name: value}
else:
kv = attribute_to_kwarg(attr)
kwargs.update(kv)
node = onnx.helper.make_node( # noqa: PLW2901
node.op_type, node.input, node.output, name=node.name, **kwargs
)
new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack))
graph.ClearField("node")
graph.node.extend(new_nodes)
graph_stack.pop()
return graph
def process(self):
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()
has_ms_domain = False
for opset in opset_import:
if opset.domain == "com.microsoft":
has_ms_domain = True
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
self._process_subgraph(graph_stack)
self.model.clean_initializers()
def parse_args():
parser = argparse.ArgumentParser(
description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices.
A weight matrix is partitioned into blocks, where each block is a contiguous
subset inside the flattened transposed weight matrix. Each block is quantized
into a set of 4b integers with an absolute value scaling factor.
"""
)
parser.add_argument("--input_model", required=True, help="Path to the input model file")
parser.add_argument("--output_model", required=True, help="Path to the output model file")
parser.add_argument(
"--quant_type",
required=False,
default=1,
options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
help="Quantization data type. 0: FP4, 1: NF4",
)
parser.add_argument(
"--block_size",
required=False,
default=64,
description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
)
parser.add_argument("-v", "--verbose", required=False, action="store_true")
parser.set_defaults(verbose=False)
parser.add_argument(
"--nodes_to_exclude",
nargs="+",
type=str,
required=False,
default=[],
help="Specify the nodes to be excluded from quantization with node names",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
input_model_path = args.input_model
output_model_path = args.output_model
if os.path.exists(output_model_path):
logger.error(f"file {output_model_path} already exists")
raise Exception(f"file {output_model_path} already exists")
model = onnx.load(input_model_path)
quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude)
quant.process()
quant.model.save_model_to_file(output_model_path, True)

View file

@ -0,0 +1,151 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef ORT_MINIMAL_BUILD
#include "core/common/span_utils.h"
#include "core/framework/tensor.h"
#include "core/mlas/inc/mlas_q4.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/inference_session.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/framework/test_utils.h"
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "core/util/qmath.h"
#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
#include <chrono>
#include <random>
#include "gtest/gtest.h"
#include "gmock/gmock.h"
namespace onnxruntime {
namespace test {
void QuantizeDequantizeBnb4(std::vector<float>& raw_vals, // N X K
std::vector<uint8_t>& quant_vals,
std::vector<float>& absmax,
int32_t quant_type,
int32_t N,
int32_t K,
int32_t block_size) {
OrtThreadPoolParams to;
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
concurrency::ThreadPoolType::INTRA_OP);
contrib::QuantizeBlockwiseBnb4<float>(
quant_vals.data(),
raw_vals.data(),
absmax.data(),
block_size,
quant_type,
N,
K,
tp.get());
contrib::DequantizeBlockwiseBnb4<float>(
raw_vals.data(),
quant_vals.data(),
absmax.data(),
block_size,
quant_type,
N,
K,
tp.get());
}
void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) {
RandomValueGenerator random{1234};
std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
// quantizer expects transposed weights, N X K
std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({N, K}), 0.0f, 0.25f));
int64_t numel = N * K;
int64_t quantized_numel = (numel + 1) / 2;
int64_t total_block_count = (numel + block_size - 1) / block_size;
std::vector<uint8_t> input1_vals(quantized_numel);
std::vector<float> absmax(total_block_count);
QuantizeDequantizeBnb4(input1_f_vals,
input1_vals,
absmax,
static_cast<int32_t>(quant_type),
static_cast<int32_t>(N),
static_cast<int32_t>(K),
static_cast<int32_t>(block_size));
std::vector<float> expected_vals(M * N);
for (int64_t m = 0; m < M; m++) {
for (int64_t n = 0; n < N; n++) {
float sum = 0.0f;
for (int64_t k = 0; k < K; k++) {
sum += input0_vals[m * K + k] * input1_f_vals[n * K + k];
}
expected_vals[m * N + n] = sum;
}
}
OpTester test("MatMulBnb4", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("quant_type", quant_type);
if (use_float16) {
test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
test.AddInput<MLFloat16>("absmax", {total_block_count}, ToFloat16(absmax), true);
test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
test.SetOutputAbsErr("Y", 0.02f);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else {
test.AddInput<float>("A", {M, K}, input0_vals, false);
test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
test.AddInput<float>("absmax", {total_block_count}, absmax, true);
test.AddOutput<float>("Y", {M, N}, expected_vals);
test.Run();
}
}
TEST(MatMulBnb4, Float32) {
for (auto qt : {0, 1}) {
for (auto M : {1, 2, 100}) {
for (auto N : {1, 2, 32, 288}) {
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
for (auto block_size : {16, 32, 64, 128}) {
RunTest(qt, M, N, K, block_size, false);
}
}
}
}
}
}
#if defined(USE_CUDA)
TEST(MatMulBnb4, Float16) {
for (auto qt : {0, 1}) {
for (auto M : {1, 2, 100}) {
for (auto N : {1, 2, 32, 288}) {
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
for (auto block_size : {16, 32, 64, 128}) {
RunTest(qt, M, N, K, block_size, true);
}
}
}
}
}
}
#endif
} // namespace test
} // namespace onnxruntime
#endif // ORT_MINIMAL_BUILD

View file

@ -0,0 +1,186 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import tempfile
import unittest
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, Tuple, Union
import numpy as np
import onnx
from onnx import TensorProto, helper
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count
from onnxruntime.quantization import quant_utils
quant_maps = {
0: [
0.00000000,
5.208333333e-03,
0.66666667,
1.00000000,
0.33333333,
0.50000000,
0.16666667,
0.25000000,
-0.00000000,
-5.208333333e-03,
-0.66666667,
-1.00000000,
-0.33333333,
-0.50000000,
-0.16666667,
-0.25000000,
],
1: [
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0,
],
}
class TestOpMatMulBnb4(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.")
@classmethod
def tearDownClass(cls):
cls._tmp_model_dir.cleanup()
def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray:
rows, cols = shape
line = np.zeros(shape)
line = line.reshape(-1)
quant_map = np.array(quant_maps[quant_type], dtype=np.float32)
v = 0
for i in range(line.shape[0]):
line[i] = quant_map[v]
v += 1
if v >= 16:
v = 0
# bnb quantization quantizes weight.T after flattening
line = line.reshape(cols, rows).transpose()
return line.reshape(shape)
def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds:
input_data_list = []
for _i in range(n):
inputs = {}
for name, shape in name2shape.items():
inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)})
input_data_list.extend([inputs])
dr = TestDataFeeds(input_data_list)
return dr
def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None:
# (input)
# |
# MatMul
# |
# (output)
input_name = "input"
output_name = "output"
initializers = []
def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str):
weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32)
initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name))
return onnx.helper.make_node(
"MatMul",
[input_name, weight_name],
[output_name],
)
# for this to work (in_features * out_features) % block_size == 0
in_features = 52
out_features = 288
# make MatMul node
matmul_node = make_matmul(
input_name,
[in_features, out_features],
"linear1.weight",
output_name,
)
# make graph
input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features])
output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features])
graph_name = "matmul_bnb4_test"
graph = helper.make_graph(
[matmul_node],
graph_name,
[input_tensor],
[output_tensor],
initializer=initializers,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
model.ir_version = 7 # use stable onnx ir version
onnx.save(model, output_model_path)
def quant_test(self, quant_type: int, block_size: int):
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute())
self.construct_model_matmul(model_fp32_path, quant_type)
data_reader = self.input_feeds(1, {"input": [100, 52]})
model_bnb4_path = str(
Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute()
)
# Quantize fp32 model to bnb4 model
from onnxruntime.quantization import matmul_bnb4_quantizer
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size)
quant.process()
quant.model.save_model_to_file(model_bnb4_path, False)
quant_nodes = {"MatMulBnb4": 1}
check_op_type_count(self, model_bnb4_path, **quant_nodes)
data_reader.rewind()
try:
check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next())
except Exception as exception:
raise exception
@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
)
def test_quantize_matmul_bnb4_fp4(self):
np.random.seed(13)
self.quant_test(0, 64)
@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
)
def test_quantize_matmul_bnb4_nf4(self):
np.random.seed(13)
self.quant_test(1, 64)
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,139 @@
#!/usr/bin/env python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
from importlib.util import find_spec
import numpy as np
import numpy.typing as npt
quant_enums = {"FP4": 0, "NF4": 1}
def quantize_block_fp4(block: npt.ArrayLike):
# quantize a block of float32 values to uint8 by simulating a binary search using pivots
# could have used (block[:,None] - quant_map).argmin(axis=1) but there are some mismatches due to
# floating point precision
# block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0
# pivots to find the quantization index
# only half of the pivots are needed since the other half is symmetric
pivots = np.array(
[0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32
)
# indices are not 0,1,2,3,4,5,6,7 because it is a floating point data type
pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8)
# signs of the block
signs = (block < 0).astype(np.uint8) * 8
# find the uint8 quantization index
# argmax finds the first occurrence of True
quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs
return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
def quantize_block_nf4(block: npt.ArrayLike):
pivots = np.array(
[
-0.8480964004993439,
-0.6106329262256622,
-0.4599952697753906,
-0.33967943489551544,
-0.23460740596055984,
-0.13791173323988914,
-0.045525018125772476,
0.03979014977812767,
0.1202552504837513,
0.2035212516784668,
0.2920137718319893,
0.3893125355243683,
0.5016634166240692,
0.6427869200706482,
0.8614784181118011,
1.0,
],
dtype=np.float32,
)
quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8)
return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str, target=None):
if len(matrix_float.shape) != 2:
raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
numel = matrix_float.size
num_blocks = (numel + block_size - 1) // block_size
quantized_numel = (numel + 1) // 2
packed = np.zeros(quantized_numel, dtype=np.uint8)
absmax = np.zeros(num_blocks, dtype=matrix_float.dtype)
flattened_matrix_float = matrix_float.flatten()
for block_idx in range(num_blocks):
block_len = min(block_size, numel - block_idx * block_size)
block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len])
block_absmax = np.max(np.abs(block))
reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0
absmax[block_idx] = block_absmax
if block_len % 2 != 0:
block = np.append(block, 0.0)
block_len += 1
block *= reciprocal_absmax
start = block_idx * block_size // 2
end = start + block_len // 2
if quant_type == "FP4":
packed[start:end] = quantize_block_fp4(block)
else:
packed[start:end] = quantize_block_nf4(block)
return (packed, absmax)
def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int, quant_type: str):
if len(matrix_float.shape) != 2:
raise ValueError("Current int4 block quantization only supports 2D tensors!")
quant_type_enum = quant_enums[quant_type]
n, k = matrix_float.shape # already transposed
numel = n * k
num_blocks = (numel + block_size - 1) // block_size
quantized_numel = (numel + 1) // 2
packed = np.zeros(quantized_numel, dtype="uint8")
absmax = np.zeros(num_blocks, dtype=matrix_float.dtype)
from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k)
return (packed, absmax)
class TestQuantizeBlockwiseBnb4(unittest.TestCase):
@unittest.skipIf(
find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4"
)
def test_quantize_blockwise_bnb4(self):
for quant_type in ["FP4", "NF4"]:
for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]:
for block_size in [16, 32, 64, 128]:
for type in [np.float32, np.float16]:
matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type)
quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type)
quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type)
assert np.allclose(quant_value_ref, quant_value)
assert np.allclose(absmax_ref, absmax)
if __name__ == "__main__":
unittest.main()