Mirror of https://github.com/saymrwulf/onnxruntime.git
Synced 2026-05-16 21:00:14 +00:00
[JS/WebGPU] Add MatMulNBits (#19446)
### Description
Add MatMulNBits to support MatMul using 4-bit quantized weights.
This commit is contained in:
parent 06269a3952
commit dfeda9019c

9 changed files with 1828 additions and 10 deletions
js/web/docs/webgpu-operators.md

@@ -62,6 +62,7 @@ Do not modify directly.*
 | LessOrEqual | ai.onnx(12-15,16+) | |
 | Log | ai.onnx(6-12,13+) | |
 | MatMul | ai.onnx(1-12,13+) | |
+| MatMulNBits | com.microsoft(1+) | |
 | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
 | MemcpyFromHost | ai.onnx(1+) | |
 | MemcpyToHost | ai.onnx(1+) | |
js/web/lib/wasm/jsep/util.ts

@@ -92,6 +92,34 @@ export class ShapeUtil {
     return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length);
   }

+  /**
+   * Convert dims for a packing type change, e.g. uint8 data packed as uint32.
+   */
+  static convertShape(dims: readonly number[], size = 4): readonly number[] {
+    const rank = dims.length;
+    if (rank === 0) {
+      return [];
+    }
+    const newDims = new Array(rank);
+    let i = rank - 1;
+    while (i >= 0) {
+      if (dims[i] % size === 0) {
+        newDims[i] = dims[i] / size;
+        break;
+      }
+      if (size % dims[i] !== 0) {
+        throw new Error('cannot convert shape');
+      }
+      newDims[i] = 1;
+      size /= dims[i];
+      i--;
+    }
+    for (i--; i >= 0; i--) {
+      newDims[i] = dims[i];
+    }
+    return newDims;
+  }
+
   /**
    * calculate the size (number of elements) from the given axis (inclusive)
    */
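
A short usage sketch of the new helper (the import path is illustrative): with the default pack size of 4, the first trailing dimension divisible by the pack size is divided, and any smaller trailing dims are folded to 1. This is exactly how the uint8 weight and zero-point shapes are converted to uint32 shapes in the MatMulNBits program below.

```ts
import {ShapeUtil} from './util';  // illustrative path

// Four uint8 elements pack into one uint32, so a last dim divisible by 4 shrinks 4x:
ShapeUtil.convertShape([16, 2, 8]);     // -> [16, 2, 2]
// Trailing dims smaller than the pack size fold to 1:
ShapeUtil.convertShape([16, 2, 2], 4);  // -> [16, 1, 1]
```
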
js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

@@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
 import {instanceNorm} from './ops/instance-norm';
 import {layerNorm} from './ops/layer-norm';
 import {matMul} from './ops/matmul';
+import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits';
 import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
 import {pad} from './ops/pad';
 import * as pool from './ops/pool';

@@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['LessOrEqual', [binaryOps.lessOrEqual]],
   ['Log', [unaryOps.log]],
   ['MatMul', [matMul]],
+  ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
   // TODO: support new attributes for MaxPool-8 and MaxPool-10
   ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
   ['Mul', [binaryOps.mul]],
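
The resolve-rules table pairs each op name with its kernel implementation and an optional attribute parser. A minimal self-contained sketch of the pattern (types simplified; not the actual OperatorImplementation definition):

```ts
type OpImpl = (context: unknown, attributes?: unknown) => void;
type AttrParser = (raw: Record<string, unknown>) => unknown;
type ResolveRule = [OpImpl, AttrParser?];

// op name -> [implementation] or [implementation, attribute parser]
const rules = new Map<string, ResolveRule>([
  ['MatMul', [(_ctx) => {/* launch the plain matmul program */}]],
  ['MatMulNBits', [
    (_ctx, _attrs) => {/* launch the quantized matmul program */},
    (raw) => ({...raw}),
  ]],
]);

const [impl, parseAttributes] = rules.get('MatMulNBits')!;
```
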
js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts (new file, 184 lines)

@@ -0,0 +1,184 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

// TODO: support quantization bits other than 4
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
  k: number;
  n: number;
  accuracyLevel: number;
  bits: number;
  blockSize: number;
}

const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
  if (inputs.length < 3 || inputs.length > 4) {
    throw new Error('MatMulNBits requires 3 or 4 inputs');
  }
  const a = inputs[0];
  const aRank = a.dims.length;
  if (a.dims[aRank - 1] !== attributes.k) {
    throw new Error('The last dim of input shape does not match the k value');
  }
  const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
  const blobSize = attributes.blockSize / 8 * attributes.bits;
  const b = inputs[1];
  if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
    throw new Error('The second input must be a 3D tensor with shape N x nBlocksPerCol x blobSize');
  }
  const scales = inputs[2];
  const scalesShape = scales.dims;
  if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
    throw new Error('scales input size error.');
  }
  if (inputs.length === 4) {
    const zeroPoints = inputs[3];
    const zeroPointsShape = zeroPoints.dims;
    const expectedZeroPointsSize =
        attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2);
    if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
      throw new Error('zeroPoints input size error.');
    }
  }
};

export const createMatMulNBitsProgramInfo =
    (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
      const a = inputs[0];
      const b = inputs[1];
      const scales = inputs[2];
      const aRank = a.dims.length;
      const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
      const outputSize = ShapeUtil.size(outputShape);

      const programUniforms: ProgramUniform[] = [
        {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
        {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
        {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
      ];
      programUniforms.push(...createTensorShapeVariables(a.dims));
      programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims)));
      programUniforms.push(...createTensorShapeVariables(scales.dims));
      if (inputs.length === 4) {
        programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
      }
      programUniforms.push(...createTensorShapeVariables(outputShape));

      const getShaderSource = (shaderHelper: ShaderHelper) => {
        const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length);
        const b = inputVariable('b', DataType.uint32, inputs[1].dims.length);
        const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
        const inputVariables = [a, b, scales];
        const zeroPoints =
            inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
        if (zeroPoints) {
          inputVariables.push(zeroPoints);
        }
        const output = outputVariable('output', inputs[0].dataType, outputShape.length);
        const uniforms: UniformsArrayType = [
          {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
          {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
        ];
        const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
        const blobSize = attributes.blockSize / 8 * attributes.bits;
        const wordPerBlob = blobSize / 4;
        const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
        return `
        fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8> {
          var result = array<${dataType}, 8>();
          var offset: u32 = 0;
          let count: u32 = 4;
          for (var i: u32 = 0; i < 8u; i++) {
            result[i] = ${dataType}(extractBits(value, offset, count));
            offset += count;
          }
          return result;
        }
        ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          var value: ${dataType} = 0.0;
          let output_indices = ${output.offsetToIndices('global_idx')};
          var a_indices: ${a.type.indices} = output_indices;
          var n = ${output.indicesGet('output_indices', aRank - 1)};
          // Two zero points are packed into one byte because uniforms.bits <= 4.
          // zero_point_offset is either 0 or 4; it is the bit offset within one byte.
          // TODO: support zero_point_offset for bits > 4
          ${
        zeroPoints ? `
          var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4;
          var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
          var zero_point_offset: u32 = 0;` :
                     ''}
          var scale_index = n * ${nBlocksPerCol};
          var b_indices: ${b.type.indices};
          ${b.indicesSet('b_indices', '0', 'n')};
          var block_offset: u32 = 0;
          for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
            // The scale and zero point are computed per block.
            let scale = ${scales.getByOffset('scale_index')};
            // The default zero point is 8 for unsigned 4-bit quantization.
            let zero_point: ${dataType} = ${
        zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0};
            ${b.indicesSet('b_indices', '1', 'block')};
            var word_offset: u32 = block_offset;
            for (var word: u32 = 0; word < ${wordPerBlob}; word++) {
              ${b.indicesSet('b_indices', '2', 'word')};
              let b_value = ${b.getByIndices('b_indices')};
              let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
              // Number of B elements per 32-bit word is 32 / bits = 32 / 4 = 8.
              var offset: u32 = word_offset;
              for (var i: u32 = 0; i < 8; i++) {
                ${a.indicesSet('a_indices', aRank - 1, 'offset')};
                let a_value = ${a.getByIndices('a_indices')};
                let b_quantized_value = b_quantized_values[i];
                let b_dequantized_value = (b_quantized_value - zero_point) * scale;
                value += a_value * b_dequantized_value;
                offset++;
              }
              word_offset += 8;
            }
            scale_index++;
            ${
        zeroPoints ? `
            if (zero_point_offset == 28) {
              zero_point_offset = 0;
              zero_point_index++;
              zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
            } else {
              zero_point_offset += 4;
            }` :
                     ''}
            block_offset += uniforms.block_size;
          }
          ${output.setByOffset('global_idx', 'value')};
        }
      `;
      };

      return {
        name: 'MatMulNBits',
        shaderCache:
            {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
        getRunData: () => ({
          outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64)},
          programUniforms
        }),
        getShaderSource
      };
    };

export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
  validateInputs(context.inputs, attributes);
  context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
};

export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>
    createAttributeWithCacheKey(attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>);
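
For readers tracing the generated WGSL, here is a CPU reference sketch (a hypothetical helper, not part of the commit) of what one shader invocation computes for a single output element, assuming bits = 4 and no zero_points input, so the default zero point of 8 applies. The nibble order matches ortUnpack8x4snorm: the lowest 4 bits of each 32-bit word come first.

```ts
function matMulNBitsReference(
    a: Float32Array,       // one row of A, length K
    bWords: Uint32Array,   // packed 4-bit weights for output column n, blockSize/8 words per block
    scales: Float32Array,  // one scale per block (nBlocksPerCol entries)
    blockSize: number): number {
  const wordsPerBlock = blockSize / 8;  // 8 quantized weights per 32-bit word
  const zeroPoint = 8;                  // default for unsigned 4-bit quantization
  let value = 0;
  let offset = 0;                       // index along the K dimension
  for (let block = 0; block < scales.length; block++) {
    const scale = scales[block];
    for (let word = 0; word < wordsPerBlock; word++) {
      const w = bWords[block * wordsPerBlock + word];
      for (let i = 0; i < 8 && offset < a.length; i++, offset++) {
        const q = (w >>> (4 * i)) & 0xF;               // extractBits(w, 4 * i, 4)
        value += a[offset] * (q - zeroPoint) * scale;  // dequantize, then multiply-accumulate
      }
    }
  }
  return value;
}
```

Each invocation produces exactly one output element this way; getRunData dispatches Math.ceil(outputSize / 64) workgroups of 64 invocations, and the out-of-bounds guard discards the remainder.
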
js/web/test/data/ops/matmulnbits.jsonc (new file, 1527 lines)

File diff suppressed because it is too large.
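
The test cases exercise the shape contract that validateInputs enforces. A worked example with hypothetical values K = 32, N = 16, blockSize = 16, bits = 4:

```ts
const k = 32, n = 16, blockSize = 16, bits = 4;

const nBlocksPerCol = Math.floor((k + blockSize - 1) / blockSize);  // ceil(32 / 16) = 2
const blobSize = blockSize / 8 * bits;                              // 8 bytes per block

// Expected input shapes/sizes:
//   B:           [n, nBlocksPerCol, blobSize] = [16, 2, 8] (uint8)
//   scales:      n * nBlocksPerCol            = 32 elements
//   zero_points: n * floor((nBlocksPerCol + 1) / 2) = 16 bytes (two 4-bit zero points per byte)
console.log(`B: [${n}, ${nBlocksPerCol}, ${blobSize}]`);  // B: [16, 2, 8]
```
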
js/web/test/suite-test-list.jsonc

@@ -1362,6 +1362,7 @@
       "less.jsonc",
       "log.jsonc",
       "matmul.jsonc",
+      "matmulnbits.jsonc",
       "matmul-broadcast.jsonc",
       "mul.jsonc",
       "mul_int32.jsonc",
onnxruntime/contrib_ops/js/js_contrib_kernels.cc

@@ -8,13 +8,14 @@ namespace contrib {
 namespace js {

 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);

 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {

@@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
 Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn function_table[] = {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
-                                                            SkipLayerNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>};
+                                                            SkipLayerNormalization)>};

   for (auto& function_table_entry : function_table) {
     KernelCreateInfo info = function_table_entry();
onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc (new file, 25 lines)

@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/js/quantization/matmul_nbits.h"
#include "core/providers/js/js_data_types.h"

namespace onnxruntime {
namespace contrib {
namespace js {

using onnxruntime::js::JsepSupportedFloatTypes;

ONNX_OPERATOR_KERNEL_EX(
    MatMulNBits,
    kMSDomain,
    1,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T1", JsepSupportedFloatTypes())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
    MatMulNBits);

}  // namespace js
}  // namespace contrib
}  // namespace onnxruntime
onnxruntime/contrib_ops/js/quantization/matmul_nbits.h (new file, 48 lines)

@@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/js/js_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace js {

using onnxruntime::js::JsKernel;

class MatMulNBits final : public JsKernel {
 public:
  MatMulNBits(const OpKernelInfo& info) : JsKernel(info),
                                          K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
                                          N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
                                          accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
                                          nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
                                          block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
    ORT_ENFORCE(nbits_ == 4,
                "Only 4b quantization is supported for MatMulNBits op; additional bits support is planned.");
    ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
                "Block size must be a power of 2 and greater than or equal to 16.");
    JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
                                 "k" : $1,
                                 "n" : $2,
                                 "accuracyLevel" : $3,
                                 "bits" : $4,
                                 "blockSize" : $5
                               }),
                               static_cast<int32_t>(K_),
                               static_cast<int32_t>(N_),
                               static_cast<int32_t>(accuracy_level_),
                               static_cast<int32_t>(nbits_),
                               static_cast<int32_t>(block_size_));
  }

 private:
  const size_t K_;
  const size_t N_;
  const int64_t accuracy_level_;
  const size_t nbits_;
  const size_t block_size_;
};

}  // namespace js
}  // namespace contrib
}  // namespace onnxruntime
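
The block_size ORT_ENFORCE above relies on the standard bit trick that a power of two has exactly one set bit. The same predicate in TypeScript (illustrative only):

```ts
const isValidBlockSize = (blockSize: number): boolean =>
    blockSize >= 16 && (blockSize & (blockSize - 1)) === 0;

isValidBlockSize(16);  // true
isValidBlockSize(24);  // false: 0b11000 & 0b10111 = 0b10000 !== 0
```
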