mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-21 02:18:09 +00:00
Add MatMul FP4 and NF4 Support (#18066)
### Description

Add a contrib op MatMulBnb4 (FP4 and NF4) and the related toolchain to support weight quantization.

This PR adds:

- a schema for the contrib op MatMulBnb4, which supports FP4 (4-bit floating point) and NF4 (4-bit NormalFloat) quantization of the weight.
- a naive implementation of MatMulBnb4 on CPU and GPU, i.e., implemented as MatMul(A, Dequantize(B)).
- a specialized GemV implementation for MatMulBnb4 and a related benchmark tool.
- a tool to quantize a model to FP4 or NF4.
This commit is contained in:
parent
d88d52eead
commit
d30d4d372a
23 changed files with 2236 additions and 0 deletions
|
|
@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
|
|||
"quantization/attention_quantization_impl.cuh"
|
||||
"quantization/dequantize_blockwise.cuh"
|
||||
"quantization/dequantize_blockwise.cu"
|
||||
"quantization/dequantize_blockwise_bnb4.cuh"
|
||||
"quantization/dequantize_blockwise_bnb4.cu"
|
||||
"quantization/matmul_bnb4.cc"
|
||||
"quantization/matmul_bnb4.cuh"
|
||||
"quantization/matmul_bnb4.cu"
|
||||
"quantization/matmul_nbits.cc"
|
||||
"quantization/matmul_nbits.cuh"
|
||||
"quantization/matmul_nbits.cu"
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.Inverse">com.microsoft.Inverse</a>
|
||||
* <a href="#com.microsoft.Irfft">com.microsoft.Irfft</a>
|
||||
* <a href="#com.microsoft.LongformerAttention">com.microsoft.LongformerAttention</a>
|
||||
* <a href="#com.microsoft.MatMulBnb4">com.microsoft.MatMulBnb4</a>
|
||||
* <a href="#com.microsoft.MatMulFpQ4">com.microsoft.MatMulFpQ4</a>
|
||||
* <a href="#com.microsoft.MatMulInteger16">com.microsoft.MatMulInteger16</a>
|
||||
* <a href="#com.microsoft.MatMulIntegerToFloat">com.microsoft.MatMulIntegerToFloat</a>
|
||||
|
|
@ -2504,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.MatMulBnb4"></a><a name="com.microsoft.matmulbnb4">**com.microsoft.MatMulBnb4**</a>
|
||||
|
||||
MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences:
|
||||
1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
|
||||
2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwise with block size specified by attribute 'block_size'.
|
||||
block_size must be a power of 2 and not smaller than 16, e.g., 16, 32, 64, 128, ...
|
||||
3. Input B's quantization constants or scales are specified by input 'absmax'.
|
||||
|
||||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
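
As a worked illustration of the two shape formulas above, the following C++ sketch computes the lengths of `B` and `absmax` for given `N`, `K` and `block_size`. The names `Bnb4Sizes`, `IsValidBlockSize` and `ComputeBnb4Sizes` are hypothetical helpers introduced here for illustration; only the formulas and the power-of-2 constraint come from the description.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper type: lengths of the two 1-D inputs of MatMulBnb4.
struct Bnb4Sizes {
  int64_t packed_bytes;  // length of input B (uint8, two 4-bit codes per byte)
  int64_t absmax_count;  // length of input absmax (one scale per block)
};

// True when block_size is a power of 2 and not smaller than 16.
inline bool IsValidBlockSize(int64_t block_size) {
  return block_size >= 16 && (block_size & (block_size - 1)) == 0;
}

// Applies [(N * K + 1) / 2] and [(N * K + block_size - 1) / block_size].
inline Bnb4Sizes ComputeBnb4Sizes(int64_t N, int64_t K, int64_t block_size) {
  assert(IsValidBlockSize(block_size));
  const int64_t numel = N * K;
  return {(numel + 1) / 2, (numel + block_size - 1) / block_size};
}
```

For example, with N = K = 4096 and block_size = 64 this gives 8,388,608 bytes for `B` and 262,144 entries for `absmax`.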
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>K</tt> : int (required)</dt>
|
||||
<dd>size of each input feature</dd>
|
||||
<dt><tt>N</tt> : int (required)</dt>
|
||||
<dd>size of each output feature</dd>
|
||||
<dt><tt>block_size</tt> : int (required)</dt>
|
||||
<dd>group size used for weight quantization. It must be a power of 2 and not smaller than 16.</dd>
|
||||
<dt><tt>quant_type</tt> : int (required)</dt>
|
||||
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>A</tt> : T1</dt>
|
||||
<dd>The input tensor, not quantized</dd>
|
||||
<dt><tt>B</tt> : T2</dt>
|
||||
<dd>1-dimensional quantized data for weight</dd>
|
||||
<dt><tt>absmax</tt> : T1</dt>
|
||||
<dd>quantization constants</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T1</dt>
|
||||
<dd>tensor. The output tensor has the same rank as the input. </dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T1</tt> : tensor(float), tensor(float16)</dt>
|
||||
<dd>Constrain input and output types to float/half_float tensors.</dd>
|
||||
<dt><tt>T2</tt> : tensor(uint8)</dt>
|
||||
<dd>Constrain quantized weight types to uint8.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.MatMulFpQ4"></a><a name="com.microsoft.matmulfpq4">**com.microsoft.MatMulFpQ4**</a>
|
||||
|
||||
Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
|
||||
|
|
|
|||
|
|
@ -457,6 +457,7 @@ Do not modify directly.*
|
|||
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|
||||
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|
||||
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|
||||
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|
||||
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|
||||
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|
||||
|
|
@ -852,6 +853,7 @@ Do not modify directly.*
|
|||
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|
||||
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|
||||
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
|
||||
#ifndef ORT_MINIMAL_BUILD
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
|
||||
#endif
|
||||
|
|
@ -270,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
|
||||
#ifndef ORT_MINIMAL_BUILD
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4)>,
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -0,0 +1,202 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define FORCEINLINE __forceinline
|
||||
#else
|
||||
#define FORCEINLINE __attribute__((always_inline)) inline
|
||||
#endif
|
||||
|
||||
typedef enum Bnb_DataType_t {
|
||||
FP4 = 0,
|
||||
NF4 = 1,
|
||||
} Bnb_DataType_t;
|
||||
|
||||
FORCEINLINE uint8_t QuantizeOneFP4(float x) {
|
||||
// FP4 with bias of 3
|
||||
// first bit is a sign
|
||||
// subnormals
|
||||
// 0b000 = 0
|
||||
// 0b001 = 0.0625
|
||||
// 0b110 = 2
|
||||
// 0b111 = 3
|
||||
// 0b100 = 4
|
||||
// 0b101 = 6
|
||||
// 0b010 = 8
|
||||
// 0b011 = 12
|
||||
|
||||
// we do a binary search
|
||||
// the pivots are divided by 12 (the FP4 absmax)
|
||||
// since we assume input data is in [-1.0, 1.0]
|
||||
|
||||
// !be careful here, it's easy to make a mistake
|
||||
// that is difficult to notice if you add an extra
|
||||
// zero somewhere!
|
||||
|
||||
uint8_t sign = x < 0 ? 0b1000 : 0b0000;
|
||||
x = fabsf(x);
|
||||
if (x > 0.29166667f) {
|
||||
if (x > 0.583333f) {
|
||||
if (x > 0.8333333f) {
|
||||
return 0b0011 + sign;
|
||||
} else {
|
||||
return 0b0010 + sign;
|
||||
}
|
||||
} else if (x > 0.4166667f) {
|
||||
return 0b101 + sign;
|
||||
} else {
|
||||
return 0b100 + sign;
|
||||
}
|
||||
} else if (x > 0.0859375f) {
|
||||
if (x > 0.20833333f) {
|
||||
return 0b0111 + sign;
|
||||
} else {
|
||||
return 0b0110 + sign;
|
||||
}
|
||||
} else if (x > 0.00260417f) {
|
||||
return 0b0001 + sign;
|
||||
} else {
|
||||
return 0b0000 + sign;
|
||||
}
|
||||
}
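
The thresholds used in `QuantizeOneFP4` can be checked by hand: each one appears to be the midpoint between two neighbouring FP4 magnitudes, divided by 12 (the FP4 absmax). The sketch below recomputes them from the magnitudes listed in the function's comment. `kFp4Magnitudes` and `Fp4Pivots` are hypothetical names used only for this illustration, not part of the commit.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// FP4 magnitudes from the comment in QuantizeOneFP4, sorted in ascending order.
static const float kFp4Magnitudes[8] = {0.0f, 0.0625f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 12.0f};

// Returns the 7 decision thresholds: midpoints of neighbouring magnitudes,
// scaled by 1/12 so they apply to inputs normalized into [-1.0, 1.0].
inline std::vector<float> Fp4Pivots() {
  std::vector<float> pivots;
  for (int i = 0; i + 1 < 8; ++i) {
    pivots.push_back((kFp4Magnitudes[i] + kFp4Magnitudes[i + 1]) / 2.0f / 12.0f);
  }
  return pivots;
}
```

For instance, the pivot between 8 and 12 is (8 + 12) / 2 / 12 = 0.8333..., which matches the innermost comparison in the function.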
|
||||
|
||||
FORCEINLINE uint8_t QuantizeOneNF4(float x) {
|
||||
if (x > 0.03979014977812767f) {
|
||||
if (x > 0.3893125355243683f) { // 1
|
||||
if (x > 0.6427869200706482f) { // 11
|
||||
if (x > 0.8614784181118011f) { // 111
|
||||
return 0b1111;
|
||||
} else {
|
||||
return 0b1110;
|
||||
}
|
||||
} else if (x > 0.5016634166240692f) { // 110
|
||||
return 0b1101;
|
||||
} else {
|
||||
return 0b1100;
|
||||
}
|
||||
} else if (x > 0.2035212516784668f) { // 10
|
||||
if (x > 0.2920137718319893f) { // 101
|
||||
return 0b1011;
|
||||
} else {
|
||||
return 0b1010;
|
||||
}
|
||||
} else if (x > 0.1202552504837513f) { // 100
|
||||
return 0b1001;
|
||||
} else {
|
||||
return 0b1000;
|
||||
}
|
||||
} else if (x > -0.33967943489551544f) { // 0
|
||||
if (x > -0.13791173323988914f) { // 01
|
||||
if (x > -0.045525018125772476f) { // 011
|
||||
return 0b0111;
|
||||
} else {
|
||||
return 0b0110;
|
||||
}
|
||||
} else if (x > -0.23460740596055984f) { // 010
|
||||
return 0b0101;
|
||||
} else {
|
||||
return 0b0100;
|
||||
}
|
||||
} else if (x > -0.6106329262256622f) { // 00
|
||||
if (x > -0.4599952697753906f) { // 001
|
||||
return 0b0011;
|
||||
} else {
|
||||
return 0b0010;
|
||||
}
|
||||
} else if (x > -0.8480964004993439f) { // 000
|
||||
return 0b0001;
|
||||
} else {
|
||||
return 0b0000;
|
||||
}
|
||||
}
|
||||
|
||||
template <int32_t DATA_TYPE>
|
||||
FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
|
||||
if constexpr (DATA_TYPE == FP4)
|
||||
return QuantizeOneFP4(x);
|
||||
else
|
||||
return QuantizeOneNF4(x);
|
||||
}
|
||||
|
||||
template <typename T, int32_t block_size, int32_t DATA_TYPE>
|
||||
FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
|
||||
float local_absmax = 0.0f;
|
||||
|
||||
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
|
||||
int32_t src_offset = block_idx * block_size;
|
||||
int32_t dst_offset = block_idx * block_size / 2;
|
||||
|
||||
for (int32_t idx = 0; idx < block_len; idx++) {
|
||||
const float v = static_cast<float>(src[src_offset + idx]);
|
||||
local_absmax = fmaxf(local_absmax, fabsf(v));
|
||||
}
|
||||
|
||||
absmax_block = static_cast<T>(local_absmax);
|
||||
const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
|
||||
|
||||
for (int32_t idx = 0; idx < block_len; idx += 2) {
|
||||
const float v0 = static_cast<float>(src[src_offset + idx]) * reciprocal_absmax;
|
||||
const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
|
||||
|
||||
const float v1 = (idx + 1 < block_len) ? static_cast<float>(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
|
||||
const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
|
||||
|
||||
dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
|
||||
}
|
||||
}
|
||||
|
||||
static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
|
||||
0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
|
||||
-0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
|
||||
-0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
|
||||
|
||||
static float nf4_qaunt_map[16] = {-1.0f,
|
||||
-0.6961928009986877f,
|
||||
-0.5250730514526367f,
|
||||
-0.39491748809814453f,
|
||||
-0.28444138169288635f,
|
||||
-0.18477343022823334f,
|
||||
-0.09105003625154495f,
|
||||
0.0f,
|
||||
0.07958029955625534f,
|
||||
0.16093020141124725f,
|
||||
0.24611230194568634f,
|
||||
0.33791524171829224f,
|
||||
0.44070982933044434f,
|
||||
0.5626170039176941f,
|
||||
0.7229568362236023f,
|
||||
1.0f};
|
||||
|
||||
template <typename T, int32_t DATA_TYPE>
|
||||
FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
|
||||
if constexpr (DATA_TYPE == FP4)
|
||||
return static_cast<T>(fp4_qaunt_map[x]);
|
||||
else
|
||||
return static_cast<T>(nf4_qaunt_map[x]);
|
||||
}
|
||||
|
||||
template <typename T, int32_t block_size, int32_t DATA_TYPE>
|
||||
FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
|
||||
int32_t block_len = std::min(block_size, numel - block_idx * block_size);
|
||||
int32_t src_offset = block_idx * block_size / 2;
|
||||
int32_t dst_offset = block_idx * block_size;
|
||||
|
||||
for (int32_t idx = 0; idx < block_len; idx += 2) {
|
||||
const uint8_t val = src[src_offset + idx / 2];
|
||||
|
||||
dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
|
||||
if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
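
To make the nibble packing in the header above concrete, here is a small self-contained sketch of an NF4 round trip for one pair of values. It copies the code book from `nf4_qaunt_map` but swaps the branch tree of `QuantizeOneNF4` for a plain nearest-entry search, which is simpler to read and should agree with the tree away from the exact thresholds. `kNf4Map`, `NearestNf4`, `PackPair` and `UnpackPair` are hypothetical names, not functions from the commit.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// NF4 code book, copied from nf4_qaunt_map.
static const float kNf4Map[16] = {
    -1.0f, -0.6961928009986877f, -0.5250730514526367f, -0.39491748809814453f,
    -0.28444138169288635f, -0.18477343022823334f, -0.09105003625154495f, 0.0f,
    0.07958029955625534f, 0.16093020141124725f, 0.24611230194568634f, 0.33791524171829224f,
    0.44070982933044434f, 0.5626170039176941f, 0.7229568362236023f, 1.0f};

// Nearest-entry search over the code book (a stand-in for the binary search).
// x is expected to be normalized by the block absmax, i.e. in [-1.0, 1.0].
inline uint8_t NearestNf4(float x) {
  uint8_t best = 0;
  for (uint8_t i = 1; i < 16; ++i) {
    if (std::fabs(x - kNf4Map[i]) < std::fabs(x - kNf4Map[best])) best = i;
  }
  return best;
}

// Packs two codes into one byte, first value in the high nibble,
// mirroring dst[...] = (vi0 << 4) | vi1 in QuantizeBlockBnb4.
inline uint8_t PackPair(float v0, float v1) {
  return static_cast<uint8_t>((NearestNf4(v0) << 4) | NearestNf4(v1));
}

// Reverses PackPair and rescales by the block absmax,
// mirroring the two lookups in DequantizeBlockBnb4.
inline void UnpackPair(uint8_t byte, float absmax, float* out0, float* out1) {
  *out0 = kNf4Map[byte >> 4] * absmax;
  *out1 = kNf4Map[byte & 0xF] * absmax;
}
```

The high-nibble-first order matters: a reader that unpacks the low nibble first would swap every pair of weights.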
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blockwise_quant_block_bnb4.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/float16.h"
|
||||
#include "core/platform/threadpool.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
template <typename T, int32_t block_size, int32_t DATA_TYPE>
|
||||
void QuantizeBlockwiseBnb4(
|
||||
uint8_t* dst, // shape: [(N * K + 1) / 2]
|
||||
const T* src, // shape: [N, K]
|
||||
T* absmax, // shape: [(N * K + block_size - 1) / block_size]
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool) {
|
||||
int32_t numel = N * K;
|
||||
int32_t total_block_count = (numel + block_size - 1) / block_size;
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
thread_pool,
|
||||
total_block_count,
|
||||
[&](ptrdiff_t block_idx) {
|
||||
QuantizeBlockBnb4<T, block_size, DATA_TYPE>(
|
||||
src,
|
||||
dst,
|
||||
absmax[block_idx],
|
||||
static_cast<int32_t>(block_idx),
|
||||
numel);
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \
|
||||
if (quant_type == FP4) \
|
||||
QuantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
|
||||
else \
|
||||
QuantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
|
||||
|
||||
template <typename T>
|
||||
void QuantizeBlockwiseBnb4(
|
||||
uint8_t* dst, // shape: [(N * K + 1) / 2]
|
||||
const T* src, // shape: [N, K]
|
||||
T* absmax, // shape: [(N * K + block_size - 1) / block_size]
|
||||
int32_t block_size,
|
||||
int32_t quant_type,
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool) {
|
||||
ORT_ENFORCE(
|
||||
quant_type == FP4 || quant_type == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
|
||||
if (block_size == 16) {
|
||||
QuantizeBlockwiseBn4DataTyped(16, quant_type);
|
||||
} else if (block_size == 32) {
|
||||
QuantizeBlockwiseBn4DataTyped(32, quant_type);
|
||||
} else if (block_size == 64) {
|
||||
QuantizeBlockwiseBn4DataTyped(64, quant_type);
|
||||
} else if (block_size == 128) {
|
||||
QuantizeBlockwiseBn4DataTyped(128, quant_type);
|
||||
} else if (block_size == 256) {
|
||||
QuantizeBlockwiseBn4DataTyped(256, quant_type);
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
#undef QuantizeBlockwiseBn4DataTyped
|
||||
|
||||
template <typename T, int32_t block_size, int32_t DATA_TYPE>
|
||||
void DequantizeBlockwiseBnb4(
|
||||
T* dst, // shape: [N, K]
|
||||
const uint8_t* src, // shape: [(N * K + 1) / 2]
|
||||
const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool) {
|
||||
int32_t numel = N * K;
|
||||
int32_t total_block_count = (numel + block_size - 1) / block_size;
|
||||
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
thread_pool,
|
||||
total_block_count,
|
||||
[&](ptrdiff_t block_idx) {
|
||||
DequantizeBlockBnb4<T, block_size, DATA_TYPE>(
|
||||
src,
|
||||
dst,
|
||||
absmax[block_idx],
|
||||
static_cast<int32_t>(block_idx),
|
||||
numel);
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \
|
||||
if (quant_type == FP4) \
|
||||
DequantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
|
||||
else \
|
||||
DequantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
|
||||
|
||||
template <typename T>
|
||||
void DequantizeBlockwiseBnb4(
|
||||
T* dst, // shape: [N, K]
|
||||
const uint8_t* src, // shape: [(N * K + 1) / 2]
|
||||
const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
|
||||
int32_t block_size,
|
||||
int32_t quant_type,
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool) {
|
||||
ORT_ENFORCE(
|
||||
quant_type == FP4 || quant_type == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
|
||||
if (block_size == 16) {
|
||||
DequantizeBlockwiseBn4DataTyped(16, quant_type);
|
||||
} else if (block_size == 32) {
|
||||
DequantizeBlockwiseBn4DataTyped(32, quant_type);
|
||||
} else if (block_size == 64) {
|
||||
DequantizeBlockwiseBn4DataTyped(64, quant_type);
|
||||
} else if (block_size == 128) {
|
||||
DequantizeBlockwiseBn4DataTyped(128, quant_type);
|
||||
} else if (block_size == 256) {
|
||||
DequantizeBlockwiseBn4DataTyped(256, quant_type);
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
#undef DequantizeBlockwiseBn4DataTyped
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
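
The non-templated overloads above map a runtime `block_size` onto a fixed set of compile-time instantiations through an if/else chain and a helper macro. The sketch below shows the same dispatch idea in isolation, using a `switch` instead of the macro. `CountBlocks` and `CountBlocksDispatch` are hypothetical stand-ins for the real kernels, and the exception replaces `ORT_NOT_IMPLEMENTED` so the example stays self-contained.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Stand-in for a kernel templated on the block size.
// It only reports how many blocks cover numel elements.
template <int32_t block_size>
int32_t CountBlocks(int32_t numel) {
  return (numel + block_size - 1) / block_size;
}

// Selects the instantiation that matches the runtime block size.
inline int32_t CountBlocksDispatch(int32_t block_size, int32_t numel) {
  switch (block_size) {
    case 16:
      return CountBlocks<16>(numel);
    case 32:
      return CountBlocks<32>(numel);
    case 64:
      return CountBlocks<64>(numel);
    case 128:
      return CountBlocks<128>(numel);
    case 256:
      return CountBlocks<256>(numel);
    default:
      throw std::invalid_argument("only block size 16, 32, 64, 128, 256 are supported.");
  }
}
```

One consequence of this pattern is that a block size such as 512 satisfies the operator's power-of-2 rule but is still rejected at run time, because no instantiation exists for it.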
|
||||
109
onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "dequantize_blockwise_bnb4.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
class MatMulBnb4 final : public OpKernel {
|
||||
public:
|
||||
MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) {
|
||||
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
|
||||
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
|
||||
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
|
||||
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
|
||||
ORT_ENFORCE(
|
||||
quant_type_ == FP4 || quant_type_ == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
int64_t K_;
|
||||
int64_t N_;
|
||||
int64_t block_size_;
|
||||
int64_t quant_type_;
|
||||
};
|
||||
|
||||
Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
const Tensor* a = ctx->Input<Tensor>(0);
|
||||
const Tensor* b_quant = ctx->Input<Tensor>(1);
|
||||
const Tensor* absmax = ctx->Input<Tensor>(2);
|
||||
|
||||
const float* a_data = a->Data<float>();
|
||||
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
|
||||
const float* absmax_data = absmax->Data<float>();
|
||||
|
||||
AllocatorPtr allocator;
|
||||
auto status = ctx->GetTempSpaceAllocator(&allocator);
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
|
||||
DequantizeBlockwiseBnb4<float>(
|
||||
tmp_b_data_ptr.get(),
|
||||
b_quant_data,
|
||||
absmax_data,
|
||||
static_cast<int32_t>(block_size_),
|
||||
static_cast<int32_t>(quant_type_),
|
||||
static_cast<int32_t>(N_),
|
||||
static_cast<int32_t>(K_),
|
||||
thread_pool);
|
||||
|
||||
constexpr bool transa = false;
|
||||
constexpr bool transb = true;
|
||||
TensorShape b_shape({N_, K_});
|
||||
MatMulComputeHelper helper;
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
|
||||
|
||||
Tensor* y = ctx->Output(0, helper.OutputShape());
|
||||
|
||||
// Bail out early if the output is going to be empty
|
||||
if (y->Shape().Size() == 0) return Status::OK();
|
||||
|
||||
auto* y_data = y->MutableData<float>();
|
||||
|
||||
const size_t max_len = helper.OutputOffsets().size();
|
||||
const size_t M = static_cast<size_t>(helper.M());
|
||||
const size_t N = static_cast<size_t>(helper.N());
|
||||
const size_t K = static_cast<size_t>(helper.K());
|
||||
const size_t lda = helper.Lda(transa);
|
||||
const size_t ldb = helper.Ldb(transb);
|
||||
|
||||
// TODO: implement with native kernel
|
||||
std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
|
||||
for (size_t i = 0; i < max_len; i++) {
|
||||
data[i].BIsPacked = false;
|
||||
data[i].A = a_data + helper.LeftOffsets()[i];
|
||||
data[i].lda = lda;
|
||||
data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i];
|
||||
data[i].ldb = ldb;
|
||||
data[i].C = y_data + helper.OutputOffsets()[i];
|
||||
data[i].ldc = N;
|
||||
data[i].alpha = 1.f;
|
||||
data[i].beta = 0.0f;
|
||||
}
|
||||
MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
MatMulBnb4,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
|
||||
MatMulBnb4);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -118,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
|
||||
|
|
@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasSoftmax)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropout)>,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda_fp16.h>
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
|
||||
#include "dequantize_blockwise_bnb4.cuh"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
template <class T>
|
||||
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
|
||||
ORT_ENFORCE(
|
||||
quant_type == FP4 || quant_type == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
|
||||
T host_quant_map[16];
|
||||
switch (quant_type) {
|
||||
case FP4:
|
||||
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
|
||||
break;
|
||||
case NF4:
|
||||
for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
|
||||
break;
|
||||
}
|
||||
CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status SetBnbQuantMap<float>(int quant_type, float* quant_map_buffer, cudaStream_t stream);
|
||||
|
||||
template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cudaStream_t stream);
|
||||
|
||||
template <class T, int TILE_SIZE, int THREADS, int NUM_PER_TH>
|
||||
__global__ void kDequantizeBlockwise(
|
||||
const T* quant_map,
|
||||
T* output,
|
||||
const uint8_t* quant_data,
|
||||
const T* absmax,
|
||||
const int block_size,
|
||||
const int n) {
|
||||
const int n_load = (gridDim.x * TILE_SIZE);
|
||||
int valid_items_load = 0;
|
||||
int valid_items_store = 0;
|
||||
const int base_idx = (blockIdx.x * TILE_SIZE);
|
||||
|
||||
T vals[NUM_PER_TH * 2];
|
||||
uint8_t qvals[NUM_PER_TH];
|
||||
T local_abs_max = T(0.0f);
|
||||
|
||||
typedef cub::BlockLoad<uint8_t, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef cub::BlockStore<T, THREADS, NUM_PER_TH * 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ typename LoadChar::TempStorage loadchar;
|
||||
__shared__ typename StoreT::TempStorage storet;
|
||||
|
||||
for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
|
||||
valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i;
|
||||
valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2;
|
||||
|
||||
local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]);
|
||||
|
||||
__syncthreads();
|
||||
LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128);
|
||||
|
||||
#pragma unroll NUM_PER_TH
|
||||
for (int j = 0; j < NUM_PER_TH; j++) {
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
|
||||
vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max;
|
||||
vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max;
|
||||
#else
|
||||
// half multiplication not supported
|
||||
vals[j * 2] = static_cast<T>(static_cast<float>(quant_map[qvals[j] >> 4]) * static_cast<float>(local_abs_max));
|
||||
vals[j * 2 + 1] =
|
||||
static_cast<T>(static_cast<float>(quant_map[qvals[j] & 0x0F]) * static_cast<float>(local_abs_max));
|
||||
#endif
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Status DequantizeBnb4(
|
||||
const T* quant_map,
|
||||
T* output,
|
||||
const uint8_t* quant_data,
|
||||
const T* absmax,
|
||||
int block_size,
|
||||
int numel,
|
||||
cudaStream_t stream) {
|
||||
int tile_size = 1024;
|
||||
kDequantizeBlockwise<T, 512, 64, 8><<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>(
|
||||
quant_map,
|
||||
output,
|
||||
quant_data,
|
||||
absmax,
|
||||
block_size / 2,
|
||||
numel);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status DequantizeBnb4<float>(
|
||||
const float* quant_map,
|
||||
float* output,
|
||||
const uint8_t* quant_data,
|
||||
const float* absmax,
|
||||
int block_size,
|
||||
int numel,
|
||||
cudaStream_t stream);
|
||||
|
||||
template Status DequantizeBnb4<half>(
|
||||
const half* quant_map,
|
||||
half* output,
|
||||
const uint8_t* quant_data,
|
||||
const half* absmax,
|
||||
int block_size,
|
||||
int numel,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
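
The launch arithmetic in `DequantizeBnb4` is easy to misread, because the kernel indexes packed bytes while `numel` counts 4-bit values. The host-side sketch below restates it under the template arguments used in the launch (`TILE_SIZE = 512`, `THREADS = 64`, `NUM_PER_TH = 8`). The constant and function names are hypothetical, and the reading of the indexing is an interpretation of the kernel rather than something the commit states.

```cpp
#include <cassert>

constexpr int kTileBytes = 512;     // TILE_SIZE: packed bytes handled per CUDA block
constexpr int kThreads = 64;        // THREADS: threads per CUDA block
constexpr int kBytesPerThread = 8;  // NUM_PER_TH: packed bytes per thread

static_assert(kThreads * kBytesPerThread == kTileBytes, "one tile is split evenly across threads");

// Each tile of 512 bytes holds 1024 values, hence the divisor of 1024 in the launch.
inline int GridSize(int numel) {
  return (numel + 2 * kTileBytes - 1) / (2 * kTileBytes);
}

// absmax entry read by a thread. The kernel receives block_size / 2 because it
// divides a byte offset, and each byte carries two quantized values.
inline int AbsmaxIndex(int cuda_block, int thread, int block_size) {
  const int first_byte = cuda_block * kTileBytes + thread * kBytesPerThread;
  return first_byte / (block_size / 2);
}
```

Each thread loads a single `absmax` value for its 8 bytes (16 values), which is only valid if those 16 values fall inside one quantization block. That is consistent with the operator's requirement that `block_size` be a power of 2 and at least 16.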
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <class T>
Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream);

template <class T>
Status DequantizeBnb4(
    const T* quant_map,
    T* output,
    const uint8_t* quant_data,
    const T* absmax,
    int block_size,
    int numel,
    cudaStream_t stream);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
144
onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc
Normal file
@@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
#include "matmul_bnb4.cuh"
#include "dequantize_blockwise_bnb4.cuh"

namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

template <typename T>
class MatMulBnb4 final : public CudaKernel {
 public:
  MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) {
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
    ORT_ENFORCE(
        quant_type_ == FP4 || quant_type_ == NF4,
        "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
  }

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  int64_t K_;
  int64_t N_;
  int64_t block_size_;
  int64_t quant_type_;
};

template <typename T>
Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
  const Tensor* a = ctx->Input<Tensor>(0);
  const Tensor* b_quant = ctx->Input<Tensor>(1);
  const Tensor* absmax = ctx->Input<Tensor>(2);

  const auto* a_data = a->Data<T>();
  const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
  const auto* absmax_data = absmax->Data<T>();

  typedef typename ToCudaType<T>::MappedType CudaT;

  // TODO: find a better way to create the quant_map without using a buffer
  // don't want to use malloc directly so asking from the caller
  // can create a __device__ static array for float but doesn't work for half
  IAllocatorUniquePtr<T> quant_map_buffer = GetScratchBuffer<T>(16, ctx->GetComputeStream());
  auto* quant_map_buffer_data = quant_map_buffer.get();
  ORT_RETURN_IF_ERROR(SetBnbQuantMap<CudaT>(
      SafeInt<int>(quant_type_),
      reinterpret_cast<CudaT*>(quant_map_buffer_data),
      static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));

  constexpr bool transa = false;
  constexpr bool transb = true;
  MatMulComputeHelper helper;
  TensorShape b_shape({N_, K_});
  ORT_RETURN_IF_ERROR(
      helper.Compute(a->Shape(), b_shape, transa, transb));

  Tensor* Y = ctx->Output(0, helper.OutputShape());
  // Bail out early if the output is going to be empty
  if (Y->Shape().Size() == 0) return Status::OK();

  bool is_4bit_done = TryMatMulBnb4(
      reinterpret_cast<const CudaT*>(quant_map_buffer_data),
      reinterpret_cast<CudaT*>(Y->MutableData<T>()),
      reinterpret_cast<const CudaT*>(a_data),
      b_quant_data,
      reinterpret_cast<const CudaT*>(absmax_data),
      SafeInt<int>(helper.M()),
      SafeInt<int>(helper.N()),
      SafeInt<int>(helper.K()),
      SafeInt<int>(block_size_),
      static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));

  if (!is_4bit_done) {
    IAllocatorUniquePtr<T> b_dequant_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
    auto* b_dequant_data = b_dequant_ptr.get();
    ORT_RETURN_IF_ERROR(DequantizeBnb4<CudaT>(
        reinterpret_cast<const CudaT*>(quant_map_buffer_data),
        reinterpret_cast<CudaT*>(b_dequant_data),
        b_quant_data,
        reinterpret_cast<const CudaT*>(absmax_data),
        SafeInt<int>(block_size_),
        SafeInt<int>(N_ * K_),
        static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));

    const CudaT alpha = ToCudaType<T>::FromFloat(1.f);
    const CudaT zero = ToCudaType<T>::FromFloat(0.f);

    CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
        GetCublasHandle(ctx),
        CUBLAS_OP_T,
        CUBLAS_OP_N,
        SafeInt<int>(helper.N()),
        SafeInt<int>(helper.M()),
        SafeInt<int>(helper.K()),
        &alpha,
        reinterpret_cast<const CudaT*>(b_dequant_data),
        SafeInt<int>(K_),
        reinterpret_cast<const CudaT*>(a_data),
        helper.Lda(transa),
        &zero,
        reinterpret_cast<CudaT*>(Y->MutableData<T>()),
        helper.Ldc(),
        GetDeviceProp()));
  }

  return Status::OK();
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
    MatMulBnb4,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
    MatMulBnb4<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    MatMulBnb4,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<MLFloat16>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
    MatMulBnb4<MLFloat16>);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
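The fallback path in `ComputeInternal` computes `Y = A x dequant(B)^T`, where the dequantized weight is laid out as `[N, K]` (hence `transb = true`). The following is a minimal host-side sketch of that arithmetic, assuming flat row-major lists; `matmul_bnb4_reference` is a hypothetical helper and it does not model the column-major argument order passed to cuBLAS.

```python
from typing import List, Sequence


def matmul_bnb4_reference(
    a: Sequence[float], b_dequant: Sequence[float], m: int, n: int, k: int
) -> List[float]:
    """Y[i, j] = sum_p A[i, p] * B[j, p], with A as [m, k] and B as [n, k], both row-major."""
    y = [0.0] * (m * n)
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                # B is indexed by (output feature j, input feature p): the transposed layout.
                acc += a[i * k + p] * b_dequant[j * k + p]
            y[i * n + j] = acc
    return y


a = [1.0, 2.0, 3.0, 4.0]                  # A: [2, 2]
b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0]        # dequantized B: [3, 2]
y = matmul_bnb4_reference(a, b, m=2, n=3, k=2)
print(y)  # [1.0, 2.0, 3.0, 3.0, 4.0, 7.0]
```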
192
onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu
Normal file
@@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <type_traits>

#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include "matmul_bnb4.cuh"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#define num_values_4bit 32
template <class T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
    int M,
    int N,
    int K,
    const T* __restrict__ A,
    const uint8_t* B,
    const T* absmax,
    const T* datatype,
    T* out,
    int lda,
    int ldb,
    int ldc,
    int block_size) {
  // per threadblock:
  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
  // 4 warps -> 4 loads per iter
  // 1x32 * 32x4 -> 1x4 outputs per thread block
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32];

  const int warp_idx = threadIdx.x / 32;
  const int warp_lane = threadIdx.x % 32;
  const int row_B = (THREADS / 32) * blockIdx.x + warp_idx;
  const int num_values_8bit = num_values_4bit / 2;
  float local_C = 0.0f;

  uint8_t local_B_4bit[num_values_8bit];
  T local_B[num_values_4bit / 4];
  T local_A[num_values_4bit / 4];
  __shared__ T quant_map[16];
  T local_absmax = T(0.0f);

  for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]);
  __syncthreads();

  // A: [1, K]
  // B: [N, K]
  for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) {
    int inner_idx_halved = inner_idx / 2;
    int offset_B = ldb * row_B;
    int absidx = ((2 * offset_B) + inner_idx) / block_size;
    local_absmax = __ldg(&(absmax[absidx]));

    if (row_B < N) {
      if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
        // this is the most important for performance considerations
        reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] =
            reinterpret_cast<const int4*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
      } else {
#pragma unroll
        for (int j = 0; j < (num_values_8bit); j++)
          if ((inner_idx_halved) + j < (K / 2))
            local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
          else
            local_B_4bit[j] = 0b01110111;
      }
    } else {
#pragma unroll
      for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111;
    }

    for (int i = 0; i < 4; i++) {
#pragma unroll
      for (int k = 0; k < num_values_8bit / 4; k++) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
        local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
        local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
#else
        // half multiplication not supported
        local_B[k * 2] =
            static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) *
                           static_cast<float>(local_absmax));
        local_B[k * 2 + 1] =
            static_cast<T>(static_cast<float>(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) *
                           static_cast<float>(local_absmax));
#endif
      }

      if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
        // this is also relatively important for performance
        if (BITS == 16) {
          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
              reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 4) + i];
        } else {
          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] =
              reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
          reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] =
              reinterpret_cast<const int4*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
        }
      } else {
#pragma unroll
        for (int k = 0; k < num_values_4bit / 4; k++) {
          if (inner_idx + (i * num_values_4bit / 4) + k < K)
            local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
          else
            local_A[k] = T(0.0f);
        }
      }

      // accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
      for (int k = 0; k < num_values_4bit / 4; k++) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530
        local_C += static_cast<float>(local_A[k] * local_B[k]);
#else
        // half multiplication not supported
        local_C += static_cast<float>(local_A[k]) * static_cast<float>(local_B[k]);
#endif
      }
    }
  }

  local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);

  if (row_B < N && warp_lane == 0) out[row_B] = T(local_C);
}

template <class T>
bool TryMatMulBnb4(
    const T* quant_map,
    T* output,
    const T* a_data,
    const uint8_t* b_data_quant,
    const T* absmax,
    int m,
    int n,
    int k,
    int block_size,
    cudaStream_t stream) {
  if (k % block_size != 0 || m > 1) {
    return false;
  }
  // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32]
  if (block_size % 32 != 0 || block_size > 4096) {
    return false;
  }

  int lda = k;
  int ldb = (k + 1) / 2;
  int ldc = n;
  int num_blocks = (n + 3) / 4;

  constexpr int bits = std::is_same_v<T, half> ? 16 : 32;
  kgemm_4bit_inference_naive<T, 128, bits><<<num_blocks, 128, 0, stream>>>(
      m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size);

  return true;
}

template bool TryMatMulBnb4<float>(
    const float* quant_map,
    float* output,
    const float* a_data,
    const uint8_t* b_data_quant,
    const float* absmax,
    int m,
    int n,
    int k,
    int block_size,
    cudaStream_t stream);

template bool TryMatMulBnb4<half>(
    const half* quant_map,
    half* output,
    const half* a_data,
    const uint8_t* b_data_quant,
    const half* absmax,
    int m,
    int n,
    int k,
    int block_size,
    cudaStream_t stream);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
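The eligibility checks and launch geometry in `TryMatMulBnb4` can be restated as a small host-side sketch. The function names `can_use_gemv_kernel` and `gemv_launch_config` are hypothetical; the conditions and formulas are copied from the C++ above, assuming 128 threads per block (4 warps, one row of B per warp).

```python
from typing import Dict


def can_use_gemv_kernel(m: int, k: int, block_size: int) -> bool:
    """Mirror of the early-return checks in TryMatMulBnb4."""
    if k % block_size != 0 or m > 1:
        return False
    if block_size % 32 != 0 or block_size > 4096:
        return False
    return True


def gemv_launch_config(n: int, k: int, threads_per_block: int = 128) -> Dict[str, int]:
    """Leading dimensions and grid size used for the GEMV kernel launch."""
    warps_per_block = threads_per_block // 32  # each warp reduces one row of B
    return {
        "lda": k,
        "ldb": (k + 1) // 2,  # two 4-bit values per byte
        "ldc": n,
        "num_blocks": (n + warps_per_block - 1) // warps_per_block,
        "threads_per_block": threads_per_block,
    }


print(can_use_gemv_kernel(m=1, k=4096, block_size=64))   # True
print(can_use_gemv_kernel(m=2, k=4096, block_size=64))   # False
print(gemv_launch_config(n=4096, k=4096)["num_blocks"])  # 1024
```

Note that the modulo test accepts any multiple of 32 up to 4096 that divides `k`, for example 96, which is a wider set than the powers of two listed in the code comment.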
26
onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh
Normal file
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <class T>
bool TryMatMulBnb4(
    const T* quant_map,
    T* output,
    const T* a_data,
    const uint8_t* b_data_quant,
    const T* absmax,
    int m,
    int n,
    int k,
    int block_size,
    cudaStream_t stream);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
@@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
      MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
    });

static const char* MatMulBnb4_ver1_doc = R"DOC(
MatMulBnb4 is a MatMul whose weight is quantized to 4 bits using either the FP4 or the NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with the following differences:
  1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'.
  2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized block-wise with block size specified by attribute 'block_size'.
     The block_size must be a power of 2 and not smaller than 16, for example 16, 32, 64, 128, ...
  3. Input B's quantization constants or scales are specified by input 'absmax'.

Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in the same type as the original type of B (float32 or float16) with shape: [(N * K + block_size - 1) / block_size].

)DOC";

ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(MatMulBnb4_ver1_doc)
    .Attr("K", "size of each input feature", AttributeProto::INT)
    .Attr("N", "size of each output feature", AttributeProto::INT)
    .Attr("block_size", "group size used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
    .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
    .Input(0, "A", "The input tensor, not quantized", "T1")
    .Input(1, "B", "1-dimensional quantized data for weight", "T2")
    .Input(2, "absmax", "quantization constants", "T1")
    .Output(0, "Y", "The output tensor. It has the same rank as the input.", "T1")
    .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.")
    .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.")
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Type inference
      propagateElemTypeFromInputToOutput(ctx, 0, 0);
      // Shape inference
      int64_t in_features = getAttribute(ctx, "K", -1);
      int64_t out_features = getAttribute(ctx, "N", -1);
      MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
    });

#ifdef ENABLE_ATEN
ONNX_CONTRIB_OPERATOR_SCHEMA(ATen)
    .SetDomain(kPytorchAtenDomain)
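The shape rules stated in the MatMulBnb4 schema doc reduce to a few integer formulas. The sketch below spells them out; the helper names are hypothetical and only the formulas come from the doc string.

```python
def is_valid_block_size(block_size: int) -> bool:
    """block_size must be a power of 2 and not smaller than 16."""
    return block_size >= 16 and (block_size & (block_size - 1)) == 0


def packed_weight_size(n: int, k: int) -> int:
    """Length of the uint8 input B: two 4-bit values per byte, rounded up."""
    return (n * k + 1) // 2


def absmax_size(n: int, k: int, block_size: int) -> int:
    """Length of the absmax input: one scale per block, rounded up."""
    return (n * k + block_size - 1) // block_size


print(packed_weight_size(4096, 4096))   # 8388608
print(absmax_size(4096, 4096, 64))      # 262144
print(is_valid_block_size(48))          # False
```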
@@ -6,6 +6,7 @@
#include <pybind11/functional.h>

#include "contrib_ops/cpu/quantization/dequantize_blockwise.h"
#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
#include "core/util/thread_utils.h"

namespace pybind11 {
@@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise(
      tp.get());
}

template <typename T>
void QuantizeMatMulBnb4Blockwise(
    py::array_t<uint8_t> dst,
    py::array_t<T> src,
    py::array_t<T> absmax,
    int32_t block_size,
    int32_t quant_type,
    int32_t N,
    int32_t K) {
  OrtThreadPoolParams to;
  auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
                                          concurrency::ThreadPoolType::INTRA_OP);

  py::buffer_info dst_buf = dst.request();
  py::buffer_info src_buf = src.request();
  py::buffer_info absmax_buf = absmax.request();

  contrib::QuantizeBlockwiseBnb4<T>(
      static_cast<uint8_t*>(dst_buf.ptr),
      static_cast<const T*>(src_buf.ptr),
      static_cast<T*>(absmax_buf.ptr),
      block_size,
      quant_type,
      N,
      K,
      tp.get());
}

void CreateQuantPybindModule(py::module& m) {
  m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<float>);
  m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise<MLFloat16>);
  m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<float>);
  m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise<MLFloat16>);
}

} // namespace python
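The binding above fills `dst` and `absmax` from `src`, but the body of `QuantizeBlockwiseBnb4` is not part of this section. The following is a sketch of what a blockwise 4-bit quantizer of this kind typically does, under explicit assumptions: each block is scaled by its absolute maximum, each normalized value is mapped to the nearest code-book entry, and codes are packed high nibble first (matching the dequantize kernels). `quantize_bnb4_reference` and `toy_map` are hypothetical, and `toy_map` is not the real FP4 or NF4 table.

```python
from typing import List, Sequence, Tuple


def quantize_bnb4_reference(
    quant_map: Sequence[float], src: Sequence[float], block_size: int
) -> Tuple[bytes, List[float]]:
    """Blockwise absmax quantization to 4-bit codes (assumed nearest-entry rounding)."""
    codes: List[int] = []
    absmax: List[float] = []
    for start in range(0, len(src), block_size):
        block = src[start:start + block_size]
        scale = max(abs(v) for v in block)
        absmax.append(scale)
        for v in block:
            normalized = v / scale if scale != 0.0 else 0.0
            # Pick the code whose table value is closest to the normalized input.
            codes.append(min(range(16), key=lambda c: abs(quant_map[c] - normalized)))
    if len(codes) % 2:
        codes.append(0)  # pad to a whole byte
    packed = bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))
    return packed, absmax


# Placeholder 16-entry table, evenly spaced in [-1, 1]; NOT the real FP4/NF4 map.
toy_map = [-1.0 + 2.0 * i / 15 for i in range(16)]
packed, absmax = quantize_bnb4_reference(toy_map, [2.0, -2.0, 0.5, 1.0], block_size=2)
print(packed.hex(), absmax)  # f0bf [2.0, 1.0]
```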
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file exposes the DequantizeBnb4 CUDA kernel to kernel explorer so that it can be run and profiled.

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <pybind11/pybind11.h>

#include <string>

#include "core/providers/cuda/tunable/cuda_tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/device_array.h"
#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"

namespace py = pybind11;

namespace onnxruntime {

// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct DequantizeBnb4Params : cuda::tunable::OpParams {
  std::string Signature() const override { return std::to_string(n_); }

  int quant_type_;
  T* output_;
  const uint8_t* quant_;
  const T* absmax_;
  T* quant_map_buffer_;
  int n_;
  int k_;
};

template <typename T>
class DequantizeBnb4 : public IKernelExplorer {
 public:
  DequantizeBnb4(
      int quant_type,
      DeviceArray& output,
      DeviceArray& quant,
      DeviceArray& absmax,
      DeviceArray& quant_map_buffer,
      int n, int k) {
    params_.tuning_ctx = TuningContext();
    params_.stream = Stream();
    params_.quant_type_ = quant_type;
    params_.output_ = static_cast<T*>(output.ptr());
    params_.quant_ = static_cast<uint8_t*>(quant.ptr());
    params_.absmax_ = static_cast<T*>(absmax.ptr());
    params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
    params_.n_ = n;
    params_.k_ = k;
  }

  void Run() override {
    ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
        params_.quant_type_,
        params_.quant_map_buffer_,
        params_.StreamHandle()));
    ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4(
        params_.quant_map_buffer_,
        params_.output_,
        params_.quant_,
        params_.absmax_,
        64,
        params_.n_ * params_.k_,
        params_.StreamHandle()));
  }

 private:
  // Parameters forwarded to SetBnbQuantMap and DequantizeBnb4 on each Run()
  using ParamsT = DequantizeBnb4Params<T>;
  ParamsT params_{};
};

#define REGISTER_OP(name, type) \
  py::class_<name<type>>(m, #name "_" #type) \
      .def(py::init<int, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int>()) \
      .def("SetRepeats", &name<type>::SetRepeats) \
      .def("Profile", &name<type>::Profile) \
      .def("Run", &name<type>::Run);

KE_REGISTER(m) {
  REGISTER_OP(DequantizeBnb4, half);
  REGISTER_OP(DequantizeBnb4, float);
}

} // namespace onnxruntime
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file exposes the MatMulBnb4 GEMV kernel (TryMatMulBnb4) to kernel explorer so that it can be run and profiled.

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <pybind11/pybind11.h>

#include <string>

#include "core/providers/cuda/tunable/cuda_tunable.h"
#include "python/tools/kernel_explorer/kernel_explorer_interface.h"
#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh"
#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh"
#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh"

namespace py = pybind11;

namespace onnxruntime {

// Extend the OpParams so that all specializations have the same parameter passing interface
template <typename T>
struct MatrixFloatBnb4Params : cuda::tunable::OpParams {
  std::string Signature() const override { return std::to_string(n_); }

  int quant_type_;
  T* output_;
  const T* a_;
  const uint8_t* b_;
  const T* absmax_;
  T* quant_map_buffer_;
  int m_;
  int n_;
  int k_;
};

template <typename T>
class MatrixFloatBnb4 : public IKernelExplorer {
 public:
  MatrixFloatBnb4(DeviceArray& output,
                  DeviceArray& a,
                  DeviceArray& b,
                  DeviceArray& absmax,
                  DeviceArray& quant_map_buffer,
                  int quant_type, int m, int n, int k) {
    params_.tuning_ctx = TuningContext();
    params_.stream = Stream();
    params_.output_ = static_cast<T*>(output.ptr());
    params_.a_ = static_cast<T*>(a.ptr());
    params_.b_ = static_cast<uint8_t*>(b.ptr());
    params_.absmax_ = static_cast<T*>(absmax.ptr());
    params_.quant_map_buffer_ = static_cast<T*>(quant_map_buffer.ptr());
    params_.quant_type_ = quant_type;
    params_.m_ = m;
    params_.n_ = n;
    params_.k_ = k;
  }

  void Run() override {
    ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap(
        params_.quant_type_,
        params_.quant_map_buffer_,
        params_.StreamHandle()));
    contrib::cuda::TryMatMulBnb4(
        params_.quant_map_buffer_,
        params_.output_,
        params_.a_,
        params_.b_,
        params_.absmax_,
        params_.m_,
        params_.n_,
        params_.k_,
        64,
        params_.StreamHandle());
  }

 private:
  // Parameters forwarded to SetBnbQuantMap and TryMatMulBnb4 on each Run()
  using ParamsT = MatrixFloatBnb4Params<T>;
  ParamsT params_{};
};

#define REGISTER_OP(name, type) \
  py::class_<name<type>>(m, #name "_" #type) \
      .def(py::init<DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, DeviceArray&, int, int, int, int>()) \
      .def("SetRepeats", &name<type>::SetRepeats) \
      .def("Profile", &name<type>::Profile) \
      .def("Run", &name<type>::Run);

KE_REGISTER(m) {
  REGISTER_OP(MatrixFloatBnb4, half);
  REGISTER_OP(MatrixFloatBnb4, float);
}

} // namespace onnxruntime
@@ -0,0 +1,92 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import sys
from dataclasses import dataclass

import kernel_explorer as ke
import numpy as np
from utils import dtype_to_bytes


def dtype_to_funcs(dtype):
    type_map = {
        "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))),
        "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))),
    }
    return type_map[dtype]


quant_enums = {"FP4": 0, "NF4": 1}


dtypes = ["float16", "float32"]
quant_types = ["FP4", "NF4"]


@dataclass
class DequantizeBnb4Metric(ke.BandwidthMetric):
    quant_type: str
    n: int
    k: int

    def report(self):
        return (
            f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
            f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}"
        )


def profile_dequantize_int4_func(qt, n, k, dtype, func):
    np.random.seed(0)
    block_size = 64
    numel = n * k
    output = np.random.rand(n, k).astype(dtype)
    quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
    absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
    quant_map_buffer = np.zeros(16).astype(dtype)

    output_d = ke.DeviceArray(output)
    quant_d = ke.DeviceArray(quant)
    absmax_d = ke.DeviceArray(absmax)
    quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
    f = getattr(ke, func)
    my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k)
    duration_ms = my_op.Profile()
    total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype)

    ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k))


def profile_with_args(qt, n, k, dtype, sort):
    with ke.benchmark(sort):
        for func in dtype_to_funcs(dtype):
            profile_dequantize_int4_func(qt, n, k, dtype, func)


def profile():
    for qt in quant_types:
        for dt in dtypes:
            for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
                profile_with_args(qt, n, k, dt, True)
                print()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("profile with args")
    group.add_argument("n", type=int)
    group.add_argument("k", type=int)
    group.add_argument("quant_type", choices=quant_types)
    group.add_argument("dtype", choices=dtypes)
    group.add_argument("--sort", action="store_true")

    if len(sys.argv) == 1:
        profile()
    else:
        args = parser.parse_args()
        profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort)
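The `total_bytes` expression in the dequantize benchmark counts the packed 4-bit input, the dequantized output, and the absmax scales. A small sketch of that accounting follows. The conversion to GB/s is an assumption: it treats the duration as microseconds, as the `us` label in the report string suggests, and is not taken from `ke.BandwidthMetric`. Both function names are hypothetical.

```python
def dequantize_total_bytes(n: int, k: int, block_size: int, dtype_bytes: int) -> float:
    """Bytes moved by one dequantize call, following the benchmark's formula."""
    numel = n * k
    packed_input = numel / 2                                  # two 4-bit codes per byte
    output_and_scales = (numel + numel / block_size) * dtype_bytes
    return packed_input + output_and_scales


def gigabytes_per_second(total_bytes: float, duration_us: float) -> float:
    """Assumed conversion: bytes per microsecond divided by 1000 gives GB/s."""
    return total_bytes / duration_us / 1e3


total = dequantize_total_bytes(4096, 4096, block_size=64, dtype_bytes=2)
print(total)  # 42467328.0
```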
136
onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py
Normal file
@@ -0,0 +1,136 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import sys
from dataclasses import dataclass

import kernel_explorer as ke
import numpy as np
from utils import dtype_to_bytes


def dtype_to_funcs(dtype):
    type_map = {
        "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))),
        "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))),
    }
    return type_map[dtype]


def dtype_to_funcs_cublas(dtype):
    type_map = {
        "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))),
        "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))),
    }
    return type_map[dtype]


quant_enums = {"FP4": 0, "NF4": 1}


dtypes = ["float16", "float32"]
quant_types = ["FP4", "NF4"]


@dataclass
class MatrixMulMetric(ke.BandwidthMetric):
    m: int
    n: int
    k: int

    def report(self):
        return (
            f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
        )


@dataclass
class MatrixFpBnb4Metric(MatrixMulMetric):
    quant_type: str

    def report(self):
        return (
            f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s"
            f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}"
        )


def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func):
    np.random.seed(0)
    block_size = 64
    numel = n * k
    output = np.random.rand(m, n).astype(dtype)
    a = np.random.rand(m, k).astype(dtype)
    b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8")
    absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype)
    quant_map_buffer = np.zeros(16).astype(dtype)

    output_d = ke.DeviceArray(output)
    a_d = ke.DeviceArray(a)
    b_d = ke.DeviceArray(b)
    absmax_d = ke.DeviceArray(absmax)
    quant_map_buffer_d = ke.DeviceArray(quant_map_buffer)
    f = getattr(ke, func)

    my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k)
    duration_ms = my_op.Profile()
    total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))

    ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt))


def profile_gemm_func(m, n, k, dtype, func):
    np.random.seed(0)
    output = np.random.rand(m, n).astype(dtype)
    a = np.random.rand(m, k).astype(dtype)
    b = np.random.rand(k, n).astype(dtype)

    output_d = ke.DeviceArray(output)
    a_d = ke.DeviceArray(a)
    b_d = ke.DeviceArray(b)
    f = getattr(ke, func)
    my_op = f(output_d, a_d, b_d, m, n, k)
    duration_ms = my_op.Profile()
    total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype))

    ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k))


def profile_with_args(qt, m, n, k, dtype, sort):
    with ke.benchmark(sort):
        for func in dtype_to_funcs(dtype):
            profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func)

        for func in dtype_to_funcs_cublas(dtype):
            profile_gemm_func(m, n, k, dtype, func)


def profile():
    dims_m = [1]
    for qt in quant_types:
        for dt in dtypes:
            for m in dims_m:
                for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)):
                    profile_with_args(qt, m, n, k, dt, False)
                    print()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("profile with args")
    group.add_argument("m", type=int)
    group.add_argument("n", type=int)
    group.add_argument("k", type=int)
    group.add_argument("quant_type", choices=quant_types)
    group.add_argument("dtype", choices=dtypes)
    group.add_argument("--sort", action="store_true")

    if len(sys.argv) == 1:
        profile()
    else:
        args = parser.parse_args()
        profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort)
|
||||
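The `total_bytes` expression in the two profiling functions above counts `A`, `B` and the output at the full element width. As a rough sketch only, the helpers below compare that accounting with one that counts the weight as packed 4-bit codes plus one `absmax` scale per block. The names `dense_traffic_bytes`, `packed_traffic_bytes` and `gigabytes_per_second` are hypothetical and not part of the benchmark script, and how the script's metric classes report bandwidth is not shown here, so the last helper is an assumption.

```python
# Hypothetical helpers, not part of the benchmark script above.
ELEMENT_BYTES = {"float16": 2, "float32": 4}


def dense_traffic_bytes(m, n, k, dtype):
    """Same accounting as total_bytes above: every operand at full width."""
    return (m * k + n * k + m * n) * ELEMENT_BYTES[dtype]


def packed_traffic_bytes(m, n, k, dtype, block_size):
    """Alternative accounting: 4-bit packed weight plus one absmax per block."""
    elem = ELEMENT_BYTES[dtype]
    numel = n * k
    weight_bytes = (numel + 1) // 2  # two 4-bit codes per byte
    absmax_bytes = ((numel + block_size - 1) // block_size) * elem
    return (m * k + m * n) * elem + weight_bytes + absmax_bytes


def gigabytes_per_second(total_bytes, duration_ms):
    """Assumed bandwidth formula: bytes / seconds / 1e9."""
    return total_bytes / (duration_ms * 1e6)
```

For the GemV shape `m=1, n=4096, k=4096` in float16, the packed accounting is roughly a quarter of the dense one, which is why the two figures should not be compared directly.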
240
onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py
Normal file
|
|
@ -0,0 +1,240 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import onnx
|
||||
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
|
||||
|
||||
from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
|
||||
|
||||
from .onnx_model import ONNXModel
|
||||
from .quant_utils import attribute_to_kwarg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MatMulBnb4Quantizer:
|
||||
"""Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type"""
|
||||
|
||||
##################
|
||||
# quantization types, must be consistent with native code type
|
||||
# Bnb_DataType_t defined in blockwise_quant_block_bnb4.h
|
||||
|
||||
# 4b floating point with bias of 3
|
||||
FP4 = 0
|
||||
|
||||
# 4b NormalFloat
|
||||
NF4 = 1
|
||||
|
||||
def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None):
|
||||
nodes_to_exclude = nodes_to_exclude or []
|
||||
assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4]
|
||||
self.model = ONNXModel(model)
|
||||
self.quant_type = quant_type
|
||||
self.block_size = block_size
|
||||
self.nodes_to_exclude = set(nodes_to_exclude)
|
||||
|
||||
@staticmethod
|
||||
def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
|
||||
for gid in range(len(graph_path) - 1, -1, -1):
|
||||
graph = graph_path[gid]
|
||||
for tensor in graph.initializer:
|
||||
if tensor.name == name:
|
||||
return tensor, graph
|
||||
return None, None
|
||||
|
||||
def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray:
|
||||
"""4b quantize fp32/fp16 weight"""
|
||||
|
||||
if len(fpweight.shape) != 2:
|
||||
raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
|
||||
# need to copy since the transposed weight still has the original memory layout
|
||||
# Linear4bit quantizes its weight data which is the transposed weight
|
||||
fpweight_t = fpweight.transpose().copy()
|
||||
|
||||
rows, cols = fpweight.shape
|
||||
numel = rows * cols
|
||||
block_size = self.block_size
|
||||
num_blocks = (numel + block_size - 1) // block_size
|
||||
quantized_numel = (numel + 1) // 2
|
||||
|
||||
packed = np.zeros(quantized_numel, dtype="uint8")
|
||||
absmax = np.zeros(num_blocks, dtype=fpweight.dtype)
|
||||
# block wise quantization, fpweight_t is flattened and divided into blocks
|
||||
quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows)
|
||||
|
||||
return (packed, absmax)
|
||||
|
||||
def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
|
||||
"""If the node is a MatMul with a constant fp32/fp16 2-D weight, quantize the weight to FP4 or NF4 and return the new node"""
|
||||
|
||||
if node.op_type != "MatMul":
|
||||
return node # only care about MatMul for now
|
||||
|
||||
logger.debug(f"start to quantize {node.name} ...")
|
||||
if node.name in self.nodes_to_exclude:
|
||||
logger.debug(f"skip quantizing {node.name} as specified by nodes_to_exclude...")
|
||||
return node
|
||||
|
||||
inputB = node.input[1] # noqa: N806
|
||||
B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806
|
||||
if B is None:
|
||||
logger.debug("MatMul doesn't have a constant weight. Skipping quantization")
|
||||
return node # only care about constant weight
|
||||
|
||||
B_array = onnx.numpy_helper.to_array(B) # noqa: N806
|
||||
if len(B_array.shape) != 2:
|
||||
logger.debug("MatMul weight is not 2D. Skipping quantization")
|
||||
return node # can only process 2-D matrix
|
||||
|
||||
packed, absmax = self.bnb4_block_quant(B_array)
|
||||
B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
|
||||
B_quant.name = B.name + "_Bnb4"
|
||||
for input in Bs_graph.input:
|
||||
if input.name == inputB:
|
||||
Bs_graph.input.remove(input)
|
||||
break
|
||||
|
||||
absmax_tensor = onnx.numpy_helper.from_array(absmax)
|
||||
absmax_tensor.name = B.name + "_absmax"
|
||||
|
||||
Bs_graph.initializer.extend([B_quant, absmax_tensor])
|
||||
|
||||
kwargs = {}
|
||||
rows, cols = B_array.shape
|
||||
kwargs["K"] = rows
|
||||
kwargs["N"] = cols
|
||||
kwargs["block_size"] = self.block_size
|
||||
kwargs["quant_type"] = self.quant_type
|
||||
|
||||
matmul_bnb4_node = onnx.helper.make_node(
|
||||
"MatMulBnb4",
|
||||
inputs=[node.input[0], B_quant.name, absmax_tensor.name],
|
||||
outputs=[node.output[0]],
|
||||
name=node.name + "_Bnb4" if node.name else "",
|
||||
domain="com.microsoft",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
logger.debug(f"complete quantization of {node.name} ...")
|
||||
|
||||
return matmul_bnb4_node
|
||||
|
||||
def _process_subgraph(self, graph_stack: List[GraphProto]):
|
||||
new_nodes = []
|
||||
graph = graph_stack[-1]
|
||||
|
||||
for node in graph.node:
|
||||
graph_attrs = [
|
||||
attr
|
||||
for attr in node.attribute
|
||||
if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
|
||||
]
|
||||
if len(graph_attrs):
|
||||
kwargs = {}
|
||||
for attr in node.attribute:
|
||||
if attr.type == onnx.AttributeProto.GRAPH:
|
||||
# recursive call to take care of sub-graph
|
||||
graph_stack.append(attr.g)
|
||||
kv = {attr.name: self._process_subgraph(graph_stack)}
|
||||
elif attr.type == onnx.AttributeProto.GRAPHS:
|
||||
value = []
|
||||
for subgraph in attr.graphs:
|
||||
# recursive call to take care of sub-graph
|
||||
graph_stack.append(subgraph)
|
||||
value.extend([self._process_subgraph(graph_stack)])
|
||||
kv = {attr.name: value}
|
||||
else:
|
||||
kv = attribute_to_kwarg(attr)
|
||||
kwargs.update(kv)
|
||||
node = onnx.helper.make_node( # noqa: PLW2901
|
||||
node.op_type, node.input, node.output, name=node.name, **kwargs
|
||||
)
|
||||
|
||||
new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack))
|
||||
|
||||
graph.ClearField("node")
|
||||
graph.node.extend(new_nodes)
|
||||
graph_stack.pop()
|
||||
return graph
|
||||
|
||||
def process(self):
|
||||
# use a stack to keep track of sub-graphs
|
||||
graph_stack = [self.model.graph()]
|
||||
opset_import = self.model.opset_import()
|
||||
|
||||
has_ms_domain = False
|
||||
for opset in opset_import:
|
||||
if opset.domain == "com.microsoft":
|
||||
has_ms_domain = True
|
||||
if not has_ms_domain:
|
||||
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])
|
||||
|
||||
self._process_subgraph(graph_stack)
|
||||
self.model.clean_initializers()
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices.
|
||||
|
||||
A weight matrix is partitioned into blocks, where each block is a contiguous
|
||||
subset inside the flattened transposed weight matrix. Each block is quantized
|
||||
into a set of 4-bit FP4/NF4 codes with an absolute-maximum (absmax) scaling factor.
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument("--input_model", required=True, help="Path to the input model file")
|
||||
parser.add_argument("--output_model", required=True, help="Path to the output model file")
|
||||
parser.add_argument(
|
||||
"--quant_type",
|
||||
required=False,
|
||||
default=1,
|
||||
type=int,
choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
|
||||
help="Quantization data type. 0: FP4, 1: NF4",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_size",
|
||||
required=False,
|
||||
default=64,
|
||||
type=int,
help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
|
||||
)
|
||||
parser.add_argument("-v", "--verbose", required=False, action="store_true")
|
||||
parser.set_defaults(verbose=False)
|
||||
parser.add_argument(
|
||||
"--nodes_to_exclude",
|
||||
nargs="+",
|
||||
type=str,
|
||||
required=False,
|
||||
default=[],
|
||||
help="Specify the nodes to be excluded from quantization with node names",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
input_model_path = args.input_model
|
||||
output_model_path = args.output_model
|
||||
|
||||
if os.path.exists(output_model_path):
|
||||
logger.error(f"file {output_model_path} already exists")
|
||||
raise Exception(f"file {output_model_path} already exists")
|
||||
|
||||
model = onnx.load(input_model_path)
|
||||
quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(output_model_path, True)
|
||||
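The block partitioning described in the argument parser help text can be sketched in pure Python. This is an illustration only, not the native `quantize_matmul_bnb4` kernel: the helper name `quantize_blockwise_nf4` is hypothetical, the code book is copied from the `quant_maps` table in `test_op_matmul_bnb4.py`, nearest-entry rounding is assumed (the native code may break ties differently), and `block_size` is assumed to be even so that nibble positions line up across blocks.

```python
# NF4 code book, copied from quant_maps[1] in test_op_matmul_bnb4.py.
NF4_MAP = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]


def quantize_blockwise_nf4(flat, block_size):
    """Hypothetical helper: quantize a flat list of floats block by block."""
    numel = len(flat)
    num_blocks = (numel + block_size - 1) // block_size
    packed = [0] * ((numel + 1) // 2)  # two 4-bit codes per byte
    absmax = [0.0] * num_blocks        # one scale per block
    for b in range(num_blocks):
        start = b * block_size
        block = flat[start:start + block_size]
        absmax[b] = max(abs(v) for v in block)
        scale = 1.0 / absmax[b] if absmax[b] != 0 else 0.0
        for offset, v in enumerate(block):
            # Assumption: pick the nearest code-book entry for the scaled value.
            code = min(range(16), key=lambda i: abs(NF4_MAP[i] - v * scale))
            idx = start + offset
            # Even positions fill the high nibble, odd positions the low nibble.
            shift = 4 if idx % 2 == 0 else 0
            packed[idx // 2] |= code << shift
    return packed, absmax
```

Calling `quantize_blockwise_nf4(weights, 64)` on the flattened transposed weight mirrors the shapes that `bnb4_block_quant` allocates: `(numel + 1) // 2` packed bytes and one `absmax` entry per block.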
151
onnxruntime/test/contrib_ops/matmul_bnb4_test.cc
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef ORT_MINIMAL_BUILD
|
||||
|
||||
#include "core/common/span_utils.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/mlas/inc/mlas_q4.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/session/inference_session.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/optimizer/graph_transform_test_builder.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
#include "core/util/qmath.h"
|
||||
#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
void QuantizeDequantizeBnb4(std::vector<float>& raw_vals, // N X K
|
||||
std::vector<uint8_t>& quant_vals,
|
||||
std::vector<float>& absmax,
|
||||
int32_t quant_type,
|
||||
int32_t N,
|
||||
int32_t K,
|
||||
int32_t block_size) {
|
||||
OrtThreadPoolParams to;
|
||||
auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to,
|
||||
concurrency::ThreadPoolType::INTRA_OP);
|
||||
|
||||
contrib::QuantizeBlockwiseBnb4<float>(
|
||||
quant_vals.data(),
|
||||
raw_vals.data(),
|
||||
absmax.data(),
|
||||
block_size,
|
||||
quant_type,
|
||||
N,
|
||||
K,
|
||||
tp.get());
|
||||
|
||||
contrib::DequantizeBlockwiseBnb4<float>(
|
||||
raw_vals.data(),
|
||||
quant_vals.data(),
|
||||
absmax.data(),
|
||||
block_size,
|
||||
quant_type,
|
||||
N,
|
||||
K,
|
||||
tp.get());
|
||||
}
|
||||
|
||||
void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) {
|
||||
RandomValueGenerator random{1234};
|
||||
std::vector<float> input0_vals(random.Gaussian<float>(std::vector<int64_t>({M, K}), 0.0f, 0.25f));
|
||||
// quantizer expects transposed weights, N X K
|
||||
std::vector<float> input1_f_vals(random.Gaussian<float>(std::vector<int64_t>({N, K}), 0.0f, 0.25f));
|
||||
|
||||
int64_t numel = N * K;
|
||||
int64_t quantized_numel = (numel + 1) / 2;
|
||||
int64_t total_block_count = (numel + block_size - 1) / block_size;
|
||||
std::vector<uint8_t> input1_vals(quantized_numel);
|
||||
std::vector<float> absmax(total_block_count);
|
||||
|
||||
QuantizeDequantizeBnb4(input1_f_vals,
|
||||
input1_vals,
|
||||
absmax,
|
||||
static_cast<int32_t>(quant_type),
|
||||
static_cast<int32_t>(N),
|
||||
static_cast<int32_t>(K),
|
||||
static_cast<int32_t>(block_size));
|
||||
|
||||
std::vector<float> expected_vals(M * N);
|
||||
for (int64_t m = 0; m < M; m++) {
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
float sum = 0.0f;
|
||||
for (int64_t k = 0; k < K; k++) {
|
||||
sum += input0_vals[m * K + k] * input1_f_vals[n * K + k];
|
||||
}
|
||||
expected_vals[m * N + n] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
OpTester test("MatMulBnb4", 1, kMSDomain);
|
||||
test.AddAttribute<int64_t>("K", K);
|
||||
test.AddAttribute<int64_t>("N", N);
|
||||
test.AddAttribute<int64_t>("block_size", block_size);
|
||||
test.AddAttribute<int64_t>("quant_type", quant_type);
|
||||
if (use_float16) {
|
||||
test.AddInput<MLFloat16>("A", {M, K}, ToFloat16(input0_vals), false);
|
||||
test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
|
||||
test.AddInput<MLFloat16>("absmax", {total_block_count}, ToFloat16(absmax), true);
|
||||
|
||||
test.AddOutput<MLFloat16>("Y", {M, N}, ToFloat16(expected_vals));
|
||||
test.SetOutputAbsErr("Y", 0.02f);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
} else {
|
||||
test.AddInput<float>("A", {M, K}, input0_vals, false);
|
||||
test.AddInput<uint8_t>("B", {quantized_numel}, input1_vals, true);
|
||||
test.AddInput<float>("absmax", {total_block_count}, absmax, true);
|
||||
|
||||
test.AddOutput<float>("Y", {M, N}, expected_vals);
|
||||
|
||||
test.Run();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MatMulBnb4, Float32) {
|
||||
for (auto qt : {0, 1}) {
|
||||
for (auto M : {1, 2, 100}) {
|
||||
for (auto N : {1, 2, 32, 288}) {
|
||||
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
|
||||
for (auto block_size : {16, 32, 64, 128}) {
|
||||
RunTest(qt, M, N, K, block_size, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA)
|
||||
TEST(MatMulBnb4, Float16) {
|
||||
for (auto qt : {0, 1}) {
|
||||
for (auto M : {1, 2, 100}) {
|
||||
for (auto N : {1, 2, 32, 288}) {
|
||||
for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) {
|
||||
for (auto block_size : {16, 32, 64, 128}) {
|
||||
RunTest(qt, M, N, K, block_size, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // ORT_MINIMAL_BUILD
|
||||
186
onnxruntime/test/python/quantization/test_op_matmul_bnb4.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
#!/usr/bin/env python
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import TensorProto, helper
|
||||
from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count
|
||||
|
||||
from onnxruntime.quantization import quant_utils
|
||||
|
||||
quant_maps = {
|
||||
0: [
|
||||
0.00000000,
|
||||
5.208333333e-03,
|
||||
0.66666667,
|
||||
1.00000000,
|
||||
0.33333333,
|
||||
0.50000000,
|
||||
0.16666667,
|
||||
0.25000000,
|
||||
-0.00000000,
|
||||
-5.208333333e-03,
|
||||
-0.66666667,
|
||||
-1.00000000,
|
||||
-0.33333333,
|
||||
-0.50000000,
|
||||
-0.16666667,
|
||||
-0.25000000,
|
||||
],
|
||||
1: [
|
||||
-1.0,
|
||||
-0.6961928009986877,
|
||||
-0.5250730514526367,
|
||||
-0.39491748809814453,
|
||||
-0.28444138169288635,
|
||||
-0.18477343022823334,
|
||||
-0.09105003625154495,
|
||||
0.0,
|
||||
0.07958029955625534,
|
||||
0.16093020141124725,
|
||||
0.24611230194568634,
|
||||
0.33791524171829224,
|
||||
0.44070982933044434,
|
||||
0.5626170039176941,
|
||||
0.7229568362236023,
|
||||
1.0,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestOpMatMulBnb4(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.")
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._tmp_model_dir.cleanup()
|
||||
|
||||
def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray:
|
||||
rows, cols = shape
|
||||
line = np.zeros(shape)
|
||||
line = line.reshape(-1)
|
||||
quant_map = np.array(quant_maps[quant_type], dtype=np.float32)
|
||||
|
||||
v = 0
|
||||
for i in range(line.shape[0]):
|
||||
line[i] = quant_map[v]
|
||||
v += 1
|
||||
if v >= 16:
|
||||
v = 0
|
||||
|
||||
# bnb quantization quantizes weight.T after flattening
|
||||
line = line.reshape(cols, rows).transpose()
|
||||
return line.reshape(shape)
|
||||
|
||||
def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds:
|
||||
input_data_list = []
|
||||
for _i in range(n):
|
||||
inputs = {}
|
||||
for name, shape in name2shape.items():
|
||||
inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)})
|
||||
input_data_list.extend([inputs])
|
||||
dr = TestDataFeeds(input_data_list)
|
||||
return dr
|
||||
|
||||
def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None:
|
||||
# (input)
|
||||
# |
|
||||
# MatMul
|
||||
# |
|
||||
# (output)
|
||||
input_name = "input"
|
||||
output_name = "output"
|
||||
initializers = []
|
||||
|
||||
def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str):
|
||||
weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32)
|
||||
initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name))
|
||||
return onnx.helper.make_node(
|
||||
"MatMul",
|
||||
[input_name, weight_name],
|
||||
[output_name],
|
||||
)
|
||||
|
||||
# for this to work (in_features * out_features) % block_size == 0
|
||||
in_features = 52
|
||||
out_features = 288
|
||||
# make MatMul node
|
||||
matmul_node = make_matmul(
|
||||
input_name,
|
||||
[in_features, out_features],
|
||||
"linear1.weight",
|
||||
output_name,
|
||||
)
|
||||
|
||||
# make graph
|
||||
input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features])
|
||||
output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features])
|
||||
graph_name = "matmul_bnb4_test"
|
||||
graph = helper.make_graph(
|
||||
[matmul_node],
|
||||
graph_name,
|
||||
[input_tensor],
|
||||
[output_tensor],
|
||||
initializer=initializers,
|
||||
)
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
model.ir_version = 7 # use stable onnx ir version
|
||||
|
||||
onnx.save(model, output_model_path)
|
||||
|
||||
def quant_test(self, quant_type: int, block_size: int):
|
||||
model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute())
|
||||
self.construct_model_matmul(model_fp32_path, quant_type)
|
||||
data_reader = self.input_feeds(1, {"input": [100, 52]})
|
||||
|
||||
model_bnb4_path = str(
|
||||
Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute()
|
||||
)
|
||||
|
||||
# Quantize fp32 model to bnb4 model
|
||||
from onnxruntime.quantization import matmul_bnb4_quantizer
|
||||
|
||||
model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
|
||||
quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size)
|
||||
quant.process()
|
||||
quant.model.save_model_to_file(model_bnb4_path, False)
|
||||
|
||||
quant_nodes = {"MatMulBnb4": 1}
|
||||
check_op_type_count(self, model_bnb4_path, **quant_nodes)
|
||||
|
||||
data_reader.rewind()
|
||||
|
||||
check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next())
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_bnb4"
|
||||
)
|
||||
def test_quantize_matmul_bnb4_fp4(self):
|
||||
np.random.seed(13)
|
||||
self.quant_test(0, 64)
|
||||
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_bnb4"
|
||||
)
|
||||
def test_quantize_matmul_bnb4_nf4(self):
|
||||
np.random.seed(13)
|
||||
self.quant_test(1, 64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
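The weight layout produced by `fill_bnb4_data` above can be illustrated without NumPy. The sketch below assumes the intent is that the flattened transpose of the weight cycles through the code book, presumably so that every value is an exact code-book entry and quantization is close to lossless. Both helper names are hypothetical.

```python
def fill_cyclic_transposed(rows, cols, quant_map):
    """Hypothetical helper: build a rows x cols weight whose flattened
    transpose cycles through quant_map."""
    numel = rows * cols
    line = [quant_map[i % len(quant_map)] for i in range(numel)]
    # Read `line` as a (cols, rows) row-major matrix, then transpose it.
    return [[line[c * rows + r] for c in range(cols)] for r in range(rows)]


def flatten_transposed(weight):
    """Flatten weight.T in row-major order, the order the quantizer consumes."""
    rows, cols = len(weight), len(weight[0])
    return [weight[r][c] for c in range(cols) for r in range(rows)]
```

With a 16-entry code book and a block size that is a multiple of 16, every block of the flattened transpose contains an entry of magnitude 1.0, so each block's `absmax` is 1.0 in this sketch.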
|
|
@ -0,0 +1,139 @@
|
|||
#!/usr/bin/env python
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
from importlib.util import find_spec
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
quant_enums = {"FP4": 0, "NF4": 1}
|
||||
|
||||
|
||||
def quantize_block_fp4(block: npt.ArrayLike):
|
||||
# quantize a block of float32 values to uint8 by simulating a binary search using pivots
|
||||
# could have used (block[:,None] - quant_map).argmin(axis=1) but there are some mismatches due to
|
||||
# floating point precision
|
||||
# block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0
|
||||
|
||||
# pivots to find the quantization index
|
||||
# only half of the pivots are needed since the other half is symmetric
|
||||
pivots = np.array(
|
||||
[0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32
|
||||
)
|
||||
# indices are not 0,1,2,3,4,5,6,7 because it is a floating point data type
|
||||
pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8)
|
||||
|
||||
# signs of the block
|
||||
signs = (block < 0).astype(np.uint8) * 8
|
||||
|
||||
# find the uint8 quantization index
|
||||
# argmax finds the first occurrence of True
|
||||
quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs
|
||||
|
||||
return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
|
||||
|
||||
|
||||
def quantize_block_nf4(block: npt.ArrayLike):
|
||||
pivots = np.array(
|
||||
[
|
||||
-0.8480964004993439,
|
||||
-0.6106329262256622,
|
||||
-0.4599952697753906,
|
||||
-0.33967943489551544,
|
||||
-0.23460740596055984,
|
||||
-0.13791173323988914,
|
||||
-0.045525018125772476,
|
||||
0.03979014977812767,
|
||||
0.1202552504837513,
|
||||
0.2035212516784668,
|
||||
0.2920137718319893,
|
||||
0.3893125355243683,
|
||||
0.5016634166240692,
|
||||
0.6427869200706482,
|
||||
0.8614784181118011,
|
||||
1.0,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8)
|
||||
|
||||
return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2])
|
||||
|
||||
|
||||
def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str, target=None):
|
||||
if len(matrix_float.shape) != 2:
|
||||
raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
|
||||
|
||||
numel = matrix_float.size
|
||||
num_blocks = (numel + block_size - 1) // block_size
|
||||
quantized_numel = (numel + 1) // 2
|
||||
|
||||
packed = np.zeros(quantized_numel, dtype=np.uint8)
|
||||
absmax = np.zeros(num_blocks, dtype=matrix_float.dtype)
|
||||
|
||||
flattened_matrix_float = matrix_float.flatten()
|
||||
for block_idx in range(num_blocks):
|
||||
block_len = min(block_size, numel - block_idx * block_size)
|
||||
block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len])
|
||||
|
||||
block_absmax = np.max(np.abs(block))
|
||||
reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0
|
||||
absmax[block_idx] = block_absmax
|
||||
|
||||
if block_len % 2 != 0:
|
||||
block = np.append(block, 0.0)
|
||||
block_len += 1
|
||||
|
||||
block *= reciprocal_absmax
|
||||
start = block_idx * block_size // 2
|
||||
end = start + block_len // 2
|
||||
if quant_type == "FP4":
|
||||
packed[start:end] = quantize_block_fp4(block)
|
||||
else:
|
||||
packed[start:end] = quantize_block_nf4(block)
|
||||
|
||||
return (packed, absmax)
|
||||
|
||||
|
||||
def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int, quant_type: str):
|
||||
if len(matrix_float.shape) != 2:
|
||||
raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
|
||||
quant_type_enum = quant_enums[quant_type]
|
||||
|
||||
n, k = matrix_float.shape # already transposed
|
||||
numel = n * k
|
||||
num_blocks = (numel + block_size - 1) // block_size
|
||||
quantized_numel = (numel + 1) // 2
|
||||
|
||||
packed = np.zeros(quantized_numel, dtype="uint8")
|
||||
absmax = np.zeros(num_blocks, dtype=matrix_float.dtype)
|
||||
from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
|
||||
|
||||
quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k)
|
||||
return (packed, absmax)
|
||||
|
||||
|
||||
class TestQuantizeBlockwiseBnb4(unittest.TestCase):
|
||||
@unittest.skipIf(
|
||||
find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_bnb4"
|
||||
)
|
||||
def test_quantize_blockwise_bnb4(self):
|
||||
for quant_type in ["FP4", "NF4"]:
|
||||
for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]:
|
||||
for block_size in [16, 32, 64, 128]:
|
||||
for type in [np.float32, np.float16]:
|
||||
matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type)
|
||||
quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type)
|
||||
quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type)
|
||||
assert np.allclose(quant_value_ref, quant_value)
|
||||
assert np.allclose(absmax_ref, absmax)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
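The packing used by `quantize_block_fp4` and `quantize_block_nf4` above (first value of each pair in the high nibble, second in the low nibble) implies a straightforward inverse. The following is a minimal sketch of that inverse under the same layout assumption; `dequantize_blockwise` is a hypothetical name rather than a function from this commit, and the FP4 code book is copied from `quant_maps[0]` in `test_op_matmul_bnb4.py`.

```python
# FP4 code book, copied from quant_maps[0] in test_op_matmul_bnb4.py.
FP4_MAP = [
    0.0, 5.208333333e-03, 0.66666667, 1.0, 0.33333333, 0.5, 0.16666667, 0.25,
    -0.0, -5.208333333e-03, -0.66666667, -1.0, -0.33333333, -0.5, -0.16666667, -0.25,
]


def dequantize_blockwise(packed, absmax, quant_map, block_size, numel):
    """Hypothetical helper: expand packed 4-bit codes back to floats."""
    out = []
    for idx in range(numel):
        byte = packed[idx // 2]
        # Even positions were stored in the high nibble, odd ones in the low nibble.
        code = (byte >> 4) if idx % 2 == 0 else (byte & 0x0F)
        # Look up the normalized value and rescale by the block's absmax.
        out.append(quant_map[code] * absmax[idx // block_size])
    return out
```

For example, `dequantize_blockwise([0x3B, 0x50], [2.0, 4.0], FP4_MAP, 2, 4)` decodes the codes 3, 11, 5 and 0 and scales the first pair by 2.0 and the second pair by 4.0.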