[JS/WebGPU] Add MatMulNBits (#19446)

### Description
Add MatMulNBits to support MatMul using 4-bit quantized weights



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2024-02-17 09:19:17 -08:00 committed by GitHub
parent 06269a3952
commit dfeda9019c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 1828 additions and 10 deletions

View file

@ -62,6 +62,7 @@ Do not modify directly.*
| LessOrEqual | ai.onnx(12-15,16+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
| MatMulNBits | com.microsoft(1+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
| MemcpyFromHost | ai.onnx(1+) | |
| MemcpyToHost | ai.onnx(1+) | |

View file

@ -92,6 +92,34 @@ export class ShapeUtil {
return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length);
}
/**
* convert dims corresponding to type change to pack. ex. uint8 data to uint32
*/
static convertShape(dims: readonly number[], size = 4): readonly number[] {
const rank = dims.length;
if (rank === 0) {
return [];
}
const newDims = new Array(rank);
let i = rank - 1;
while (i >= 0) {
if (dims[i] % size === 0) {
newDims[i] = dims[i] / size;
break;
}
if (size % dims[i] !== 0) {
throw new Error('cannot convert shape');
}
newDims[i] = 1;
size /= dims[i];
i--;
}
for (i--; i >= 0; i--) {
newDims[i] = dims[i];
}
return newDims;
}
/**
* calculate the size (number of elements) from the given axis (inclusive)
*/

View file

@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm} from './ops/instance-norm';
import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
import {pad} from './ops/pad';
import * as pool from './ops/pool';
@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['LessOrEqual', [binaryOps.lessOrEqual]],
['Log', [unaryOps.log]],
['MatMul', [matMul]],
['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
['Mul', [binaryOps.mul]],

View file

@ -0,0 +1,184 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common';
// TODO support quantization bits not equal to 4
export interface MatMulNBitsAttributes extends AttributeWithCacheKey {
k: number;
n: number;
accuracyLevel: number;
bits: number;
blockSize: number;
}
const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => {
if (inputs.length < 3 || inputs.length > 4) {
throw new Error('MatMulNBits requires 3 or 4 inputs');
}
const a = inputs[0];
const aRank = a.dims.length;
if (a.dims[aRank - 1] !== attributes.k) {
throw new Error('The last dim of input shape does not match the k value');
}
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const blobSize = attributes.blockSize / 8 * attributes.bits;
const b = inputs[1];
if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) {
throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize');
}
const scales = inputs[2];
const scalesShape = scales.dims;
if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) {
throw new Error('scales input size error.');
}
if (inputs.length === 4) {
const zeroPoints = inputs[3];
const zeroPointsShape = zeroPoints.dims;
const expectedZeroPointsSize =
attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2);
if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) {
throw new Error('zeroPoints input size error.');
}
}
};
export const createMatMulNBitsProgramInfo =
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
const a = inputs[0];
const b = inputs[1];
const scales = inputs[2];
const aRank = a.dims.length;
const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n);
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
];
programUniforms.push(...createTensorShapeVariables(a.dims));
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims)));
programUniforms.push(...createTensorShapeVariables(scales.dims));
if (inputs.length === 4) {
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length);
const b = inputVariable('b', DataType.uint32, inputs[1].dims.length);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
const zeroPoints =
inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined;
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'},
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
];
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const blobSize = attributes.blockSize / 8 * attributes.bits;
const wordPerBlob = blobSize / 4;
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `
fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{
var result = array<${dataType}, 8>();
var offset: u32 = 0;
let count: u32 = 4;
for (var i: u32 = 0; i < 8u; i++) {
result[i] = ${dataType}(extractBits(value, offset, count));
offset += count;
}
return result;
}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var value: ${dataType} = 0.0;
let output_indices = ${output.offsetToIndices('global_idx')};
var a_indices: ${a.type.indices} = output_indices;
var n = ${output.indicesGet('output_indices', aRank - 1)};
// Two zero points are packed into one byte because uniforms.bits <= 4.
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
// TODO support zero_point_offset for bits > 4
${
zeroPoints ? `
var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = 0;` :
''}
var scale_idex = n * ${nBlocksPerCol};
var b_indices: ${b.type.indices};
${b.indicesSet('b_indices', '0', 'n')};
var block_offset: u32 = 0;
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
// The scale and zero points are computed per block.
let scale = ${scales.getByOffset('scale_idex')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point: ${dataType} = ${
zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0};
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
for (var word: u32 = 0; word < ${wordPerBlob}; word++) {
${b.indicesSet('b_indices', '2', 'word')};
let b_value = ${b.getByIndices('b_indices')};
let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value);
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
var offset: u32 = word_offset;
for (var i: u32 = 0; i < 8; i++) {
${a.indicesSet('a_indices', aRank - 1, 'offset')};
let a_value = ${a.getByIndices('a_indices')};
let b_quantized_value = b_quantized_values[i];
let b_dequantized_value = (b_quantized_value - zero_point) * scale;
value += a_value * b_dequantized_value;
offset++;
}
word_offset += 8;
}
scale_idex++;
${
zeroPoints ? `
if (zero_point_offset == 28) {
zero_point_offset = 0;
zero_point_index++;
zero_point_word = ${zeroPoints.getByOffset('zero_point_index')};
} else {
zero_point_offset += 4;
}` :
''}
block_offset += uniforms.block_size;
}
${output.setByOffset('global_idx', 'value')};
}
`;
};
return {
name: 'MatMulNBits',
shaderCache:
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64)},
programUniforms
}),
getShaderSource
};
};
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
};
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>
createAttributeWithCacheKey(attributes as Omit<MatMulNBitsAttributes, keyof AttributeWithCacheKey>);

File diff suppressed because it is too large Load diff

View file

@ -1362,6 +1362,7 @@
"less.jsonc",
"log.jsonc",
"matmul.jsonc",
"matmulnbits.jsonc",
"matmul-broadcast.jsonc",
"mul.jsonc",
"mul_int32.jsonc",

View file

@ -8,13 +8,14 @@ namespace contrib {
namespace js {
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>};
SkipLayerNormalization)>};
for (auto& function_table_entry : function_table) {
KernelCreateInfo info = function_table_entry();

View file

@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/js/quantization/matmul_nbits.h"
#include "core/providers/js/js_data_types.h"
namespace onnxruntime {
namespace contrib {
namespace js {
using onnxruntime::js::JsepSupportedFloatTypes;
ONNX_OPERATOR_KERNEL_EX(
MatMulNBits,
kMSDomain,
1,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", JsepSupportedFloatTypes())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
MatMulNBits);
} // namespace js
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace js {
using onnxruntime::js::JsKernel;
class MatMulNBits final : public JsKernel {
public:
MatMulNBits(const OpKernelInfo& info) : JsKernel(info),
K_{narrow<size_t>(info.GetAttr<int64_t>("K"))},
N_{narrow<size_t>(info.GetAttr<int64_t>("N"))},
accuracy_level_{info.GetAttrOrDefault<int64_t>("accuracy_level", 0)},
nbits_{narrow<size_t>(info.GetAttr<int64_t>("bits"))},
block_size_{narrow<size_t>(info.GetAttr<int64_t>("block_size"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)),
"Block size must be a power of 2 and greater than or equal to 16.");
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({
"k" : $1,
"n" : $2,
"accuracyLevel" : $3,
"bits" : $4,
"blockSize" : $5
}),
static_cast<int32_t>(K_),
static_cast<int32_t>(N_),
static_cast<int32_t>(accuracy_level_),
static_cast<int32_t>(nbits_),
static_cast<int32_t>(block_size_));
}
private:
const size_t K_;
const size_t N_;
const int64_t accuracy_level_;
const size_t nbits_;
const size_t block_size_;
};
} // namespace js
} // namespace contrib
} // namespace onnxruntime