[JS/WebGPU] Improve MatMulNBits perf (#19974)

### Description
<!-- Describe your changes. -->
Improve performance using shared memory


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Satya Kumar Jandhyala 2024-04-12 11:03:05 -07:00 committed by GitHub
parent 794d39a977
commit b33216be4c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 1090 additions and 121 deletions

View file

@ -92,6 +92,17 @@ class ComputeContextImpl implements ComputeContext {
this.inputs = inputs;
}
getMaxComputeWorkgroupSizes(): [number, number, number] {
return [
this.backend.device.limits.maxComputeWorkgroupSizeX, this.backend.device.limits.maxComputeWorkgroupSizeY,
this.backend.device.limits.maxComputeWorkgroupSizeZ
];
}
getMaxComputeWorkgroupStoragesize(): number {
return this.backend.device.limits.maxComputeWorkgroupStorageSize;
}
compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {DataType, getTensorElementSize} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@ -50,38 +50,49 @@ const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAt
};
export const createMatMulNBitsProgramInfo =
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => {
(inputs: readonly TensorView[], attributes: MatMulNBitsAttributes,
maxComputeWorkgroupSizes: [number, number, number], maxComputeWorkgroupStorageSize: number): ProgramInfo => {
const inputShape = inputs[0].dims;
const aRank = inputShape.length;
const outputShape = inputShape.slice(0, aRank - 1).concat(attributes.n);
const m = inputShape[aRank - 2];
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const dimAOuter = inputShape[aRank - 2];
const dimInner = attributes.k;
const dimBOuter = attributes.n;
const batchDims = inputShape.slice(0, aRank - 2);
const batchSize = ShapeUtil.size(batchDims);
const blobSize = attributes.blockSize / 8 * attributes.bits;
const blobSizeInWords = blobSize / 4;
const outputNumber = getMaxComponents(m);
const components = getMaxComponents(attributes.n);
const dataType = inputs[0].dataType;
const outputNumber = getMaxComponents(dimAOuter);
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
const elementSize = getTensorElementSize(dataType)!;
const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize;
const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize);
const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0;
const components = (!useBlockwiseMatMulNBits || maxNumberOfComponents >= 4) ? getMaxComponents(dimBOuter) :
((maxNumberOfComponents >= 2) && getMaxComponents(dimBOuter) >= 2) ? 2 :
1;
const outputShape = batchDims.concat([dimAOuter, dimBOuter]);
const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k},
{type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel},
{type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize}
];
const aShape = inputShape.slice();
aShape.splice(-1, 1, attributes.k / aComponents);
const programUniforms: ProgramUniform[] = useBlockwiseMatMulNBits ?
[] :
[{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.blockSize}];
const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents];
const bShape = ShapeUtil.convertShape(inputs[1].dims).slice();
bShape.splice(-1, 1, blobSizeInWords / bComponents);
programUniforms.push(...createTensorShapeVariables(aShape));
programUniforms.push(...createTensorShapeVariables(inputShapeTemp));
programUniforms.push(...createTensorShapeVariables(bShape));
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
if (inputs.length === 4) {
programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims)));
}
const oShape = outputShape.slice();
oShape.splice(-1, 1, attributes.n / components);
programUniforms.push(...createTensorShapeVariables(oShape));
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
const inputRank = inputShapeTemp.length;
const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents);
const b = inputVariable('b', DataType.uint32, bShape.length, bComponents);
const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length);
const inputVariables = [a, b, scales];
@ -90,12 +101,9 @@ export const createMatMulNBitsProgramInfo =
if (zeroPoints) {
inputVariables.push(zeroPoints);
}
const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
const uniforms: UniformsArrayType = [
{name: 'output_size', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'}
];
const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize);
const outputRank = outputShapeTemp.length;
const output = outputVariable('output', inputs[0].dataType, outputRank, components);
const uniforms: UniformsArrayType = [{name: 'output_size', type: 'u32'}, {name: 'block_size', type: 'u32'}];
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const qDqDataType = (() => {
@ -111,43 +119,49 @@ export const createMatMulNBitsProgramInfo =
}
})();
const dequantizeImpl = `
fn dequantize(quantized: ${qDqDataType}, zero_point: ${dataType}, scale: ${dataType}) -> ${qDqDataType} {
${(() => {
const processOneBlock = `
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
${b.indicesSet('b_indices', '2', 'word')};
let b_data = ${b.getByIndices('b_indices')};
for (var i: u32 = 0; i < ${bComponents}; i++) {
let b_value: u32 = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
let b_mask: u32 = 0x0F0F0F0Fu;
let b_value_lower: vec4<u32> = unpack4xU8(b_value & b_mask);
let b_value_upper: vec4<u32> = unpack4xU8((b_value >> 4) & b_mask);
let b_quantized_values = ${qDqDataType}(${
Array.from({length: 4}, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`)
.join(', ')});
let b_dequantized_values = ${(() => {
if (aComponents === 1) {
return `var dequantized = ${qDqDataType}(${
Array.from({length: 8}, (_, i) => `(quantized[${i}] - zero_point) * scale`).join(', ')});
return dequantized;`;
return `${qDqDataType}(${
Array.from({length: 8}, (_, i) => `(b_quantized_values[${i}] - zero_point) * scale`).join(', ')});`;
} else {
return `var zero_points: ${qDqDataType} = ${qDqDataType}(${Array(8).fill('zero_point').join(',')});
return (quantized - zero_points) * scale;`;
return `(b_quantized_values - ${qDqDataType}(${Array(8).fill('zero_point').join(',')})) * scale;`;
}
})()};
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
for (var m: u32 = 0; m < ${useBlockwiseMatMulNBits ? dimAOuter : outputNumber}u; m++) {
${a.indicesSet('a_indices', inputRank - 2, useBlockwiseMatMulNBits ? 'm' : `row * ${outputNumber} + m`)};
${a.indicesSet('a_indices', inputRank - 1, 'word_offset')};
var input_offset = ${a.indicesToOffset('a_indices')};
var a_data: ${qDqDataType};
for (var j: u32 = 0; j < ${8 / aComponents}; j++) {
a_data[j] = ${a.getByOffset('input_offset')};
input_offset++;
}
${useBlockwiseMatMulNBits ? 'workgroup_shared[workgroup_shared_offset + m]' : 'output_values[m]'}${
components > 1 ? '[c]' : ''} += ${
Array
.from(
{length: 8 / aComponents},
(_, i) => `${
aComponents === 1 ? `a_data[${i}] * b_dequantized_values[${i}]` :
`dot(a_data[${i}], b_dequantized_values[${i}])`}`)
.join(' + ')};
}
word_offset += ${8 / aComponents};
}
})()}
}`;
const ortUnpack8x4snormImpl = `
fn ortUnpack8x4snorm(value: u32) -> ${qDqDataType} {
var quantized: ${qDqDataType};
var offset: u32 = 0;
let count: u32 = 4;
for (var i: u32 = 0; i < 8u; i++) {
var result = ${dataType}(extractBits(value, offset, count));
${(() => {
switch (aComponents) {
case 1:
return 'quantized[i] = result;';
case 2:
return 'quantized[i / 2][i % 2] = result;';
case 4:
return 'quantized[i / 4][i % 4] = result;';
default:
throw new Error(`${aComponents}-component is not supported.`);
}
})()}
offset += count;
}
return quantized;
}`;
const updateZeroPointIndex = zeroPoints ? `
zero_point_offset += 4;
if (zero_point_offset == 32) {
@ -157,30 +171,84 @@ export const createMatMulNBitsProgramInfo =
}` :
'';
return `
${dequantizeImpl};
${ortUnpack8x4snormImpl};
return useBlockwiseMatMulNBits ? `
var<workgroup> workgroup_shared: array<${output.type.value}, ${dimAOuter * nBlocksPerCol}>;
${shaderHelper.declareVariables(...inputVariables, output)}
${shaderHelper.mainStart([
nBlocksPerCol, 1, 1
])}
var a_indices: ${a.type.indices};
var block = local_id.x;
var col = workgroup_id.y;
var batch = workgroup_id.z;
${a.indicesSet('a_indices', '0', 'batch')};
// Two zero points are packed into one byte when uniforms.bits is 4.
for (var c: u32 = 0; c < ${components}; c++) {
let col_times_components_plus_c = col * ${components} + c;
${
zeroPoints ? `
var zero_point_bytes_per_col: u32 = (${nBlocksPerCol} + 1) / 2;
var zero_point_byte_count: u32 = col_times_components_plus_c * zero_point_bytes_per_col + (block >> 0x1u);
var zero_point_word_index: u32 = zero_point_byte_count >> 0x2u;
var zero_point_byte_offset: u32 = zero_point_byte_count & 0x3u;
var zero_point_nibble_offset: u32 = block & 0x1u;
var zero_point_bits_offset: u32 = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset;` :
''}
var b_indices: ${b.type.indices};
${b.indicesSet('b_indices', '0', 'col_times_components_plus_c')};
// The scale and zero points are computed per block.
var scales_index = col_times_components_plus_c * ${nBlocksPerCol} + block;
let scale = ${scales.getByOffset('scales_index')};
// The default zero point is 8 for unsigned 4-bit quantization.
let zero_point = ${dataType}(${zeroPoints ? '(zero_point_word) & 0xFu' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block * ${attributes.blockSize / aComponents};
var workgroup_shared_offset: u32 = block * ${dimAOuter};
${processOneBlock}
}
workgroupBarrier();
if (local_id.x == 0u) {
var output_indices: ${output.type.indices};
${output.indicesSet('output_indices', '0', 'batch')};
${output.indicesSet('output_indices', outputRank - 1, 'col')};
${output.indicesSet('output_indices', outputRank - 2, '0')};
var output_offset = ${output.indicesToOffset('output_indices')};
for (var m: u32 = 0u; m < ${dimAOuter}u; m++) {
var output_value: ${output.type.value} = ${output.type.value}(0);
var workgroup_shared_offset: u32 = m;
for (var b: u32 = 0u; b < ${nBlocksPerCol}u; b++) {
output_value += workgroup_shared[workgroup_shared_offset];
workgroup_shared_offset += ${dimAOuter};
}
${output.setByOffset('output_offset', 'output_value')};
output_offset += ${dimBOuter / components};
}
}
}` :
`
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var output_values: array<${output.type.value}, ${outputNumber}>;
var output_indices = ${output.offsetToIndices('global_idx')};
var n = ${output.indicesGet('output_indices', aRank - 1)};
var m = ${output.indicesGet('output_indices', aRank - 2)};
var col = ${output.indicesGet('output_indices', outputRank - 1)};
var row = ${output.indicesGet('output_indices', outputRank - 2)};
var a_indices: ${a.type.indices} = output_indices;
// Two zero points are packed into one byte because uniforms.bits <= 4.
// zero_point_offset is either 0 or 4. It is bit offset within one byte.
// TODO support zero_point_offset for bits > 4
${
zeroPoints ? `
var zero_point_index: u32 = n * ${components} * ((${nBlocksPerCol} + 1) / 2) / 4;
zeroPoints ? `
var zero_point_abs_offset = col * ${components} * ((${nBlocksPerCol} + 1) / 2);
var zero_point_index: u32 = zero_point_abs_offset / 4;
var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')};
var zero_point_offset: u32 = 0;` :
''}
var scale_index = n * ${nBlocksPerCol * components};
var zero_point_offset: u32 = (zero_point_abs_offset % 4) * 8;` :
''}
var scale_index = col * ${nBlocksPerCol * components};
var b_indices: ${b.type.indices};
for (var c: u32 = 0; c < ${components}; c++) {
${b.indicesSet('b_indices', '0', `n * ${components} + c`)};
${b.indicesSet('b_indices', '0', `col * ${components} + c`)};
var block_offset: u32 = 0;
for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) {
// The scale and zero points are computed per block.
@ -189,52 +257,35 @@ export const createMatMulNBitsProgramInfo =
let zero_point = ${dataType}(${zeroPoints ? 'extractBits(zero_point_word, zero_point_offset, 4)' : 8.0});
${b.indicesSet('b_indices', '1', 'block')};
var word_offset: u32 = block_offset;
for (var word: u32 = 0; word < ${blobSizeInWords}; word += ${bComponents}) {
${b.indicesSet('b_indices', '2', 'word')};
let b_data = ${b.getByIndices('b_indices')};
for (var i: u32 = 0; i < ${bComponents}; i++) {
let b_value = ${bComponents === 1 ? 'b_data' : 'b_data[word + i]'};
let b_quantized_values: ${qDqDataType} = ortUnpack8x4snorm(b_value);
let b_dequantized_values = dequantize(b_quantized_values, zero_point, scale);
// Number of B elements per 32-bit word is 32/bits = 32/4 = 8
var offset: u32 = word_offset;
for (var j: u32 = 0; j < 8/${aComponents}; j++) {
${a.indicesSet('a_indices', aRank - 1, `offset/${aComponents}`)};
for (var k: u32 = 0; k < ${outputNumber}u; k++) {
${a.indicesSet('a_indices', aRank - 2, `m * ${outputNumber} + k`)};
let a_data = ${a.getByIndices('a_indices')};
output_values[k]${components > 1 ? '[c]' : ''} += ${
aComponents === 1 ? 'a_data * b_dequantized_values[j]' : 'dot(a_data, b_dequantized_values[j])'};
}
offset += ${aComponents};
}
word_offset += 8;
}
}
${processOneBlock}
scale_index++;
${updateZeroPointIndex}
block_offset += uniforms.block_size;
block_offset += uniforms.block_size / ${aComponents};
}
// Drop the trailing 4 bits if the zero_poit_offset is not a byte boundary to align with the next byte.
${
zeroPoints ? `if (zero_point_offset % 8 > 0) {
zeroPoints ? `if (zero_point_offset % 8 > 0) {
${updateZeroPointIndex}
}` :
''}
''}
}
for (var k: u32 = 0u; k < ${outputNumber}u; k++) {
${output.indicesSet('output_indices', aRank - 2, `${outputNumber + ' * m + k'}`)};
${output.indicesSet('output_indices', outputRank - 2, `${outputNumber} * row + k`)};
${output.setByIndices('output_indices', 'output_values[k]')}
}
}`;
};
return {
name: 'MatMulNBits',
shaderCache:
{hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')},
name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits',
shaderCache: {
hint: `${attributes.cacheKey};${dimAOuter};${dataType};${inputs.length}`,
inputDependencies: Array(inputs.length).fill('rank')
},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
outputs: [{dims: outputShape, dataType}],
name: useBlockwiseMatMulNBits ? 'BlockwiseMatMulNBits' : 'MatMulNBits',
dispatchGroup: useBlockwiseMatMulNBits ? {x: 1, y: Math.ceil(dimBOuter / components), z: batchSize} :
{x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms
}),
getShaderSource
@ -243,7 +294,10 @@ export const createMatMulNBitsProgramInfo =
export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => {
validateInputs(context.inputs, attributes);
context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes));
const maxComputeWorkgroupSizes: [number, number, number] = context.getMaxComputeWorkgroupSizes();
const maxComputeWorkgroupStorageSize = context.getMaxComputeWorkgroupStoragesize();
context.compute(createMatMulNBitsProgramInfo(
context.inputs, attributes, maxComputeWorkgroupSizes, maxComputeWorkgroupStorageSize));
};
export const parseMatMulNBitsAttributes = (attributes: Record<string, unknown>): MatMulNBitsAttributes =>

View file

@ -188,6 +188,8 @@ export interface ComputeContext {
compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[];
output(index: number, dims: readonly number[]): number;
getMaxComputeWorkgroupSizes(): [number, number, number];
getMaxComputeWorkgroupStoragesize(): number;
}
export type TimestampQuery = 'none'|'inside-passes'|'at-passes';

View file

@ -1,6 +1,6 @@
[
{
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
@ -11,7 +11,7 @@
],
"cases": [
{
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric",
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
@ -56,6 +56,647 @@
}
]
},
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127
],
"dims": [8, 16],
"type": "float32"
},
{
"dims": [8, 1, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
]
},
{
"dims": [8],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
},
{
"dims": [8],
"type": "uint8",
"data": [248, 249, 250, 251, 252, 253, 254, 255]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
0, -505, -1600, -2043, -3904, -4285, -6912, -7231, 0, -1449, -5312, -6027, -12864, -12845, -22656, -21903,
0, -2393, -9024, -10011, -21824, -21405, -38400, -36575, 0, -3337, -12736, -13995, -30784, -29965, -54144,
-51247, 0, -4281, -16448, -17979, -39744, -38525, -69888, -65919, 0, -5225, -20160, -21963, -48704,
-47085, -85632, -80591, 0, -6169, -23872, -25947, -57664, -55645, -101376, -95263, 0, -7113, -27584,
-29931, -66624, -64205, -117120, -109935
]
}
]
}
]
},
{
"name": "MatMulNBits; K=32, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 32, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255
],
"dims": [8, 32],
"type": "float32"
},
{
"dims": [8, 2, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128
]
},
{
"dims": [16],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-1073, -3763, -5429, -6071, -5689, -4283, -1853, 1601, -2449, -12499, -19477, -23383, -24217, -21979,
-16669, -8287, -3825, -21235, -33525, -40695, -42745, -39675, -31485, -18175, -5201, -29971, -47573,
-58007, -61273, -57371, -46301, -28063, -6577, -38707, -61621, -75319, -79801, -75067, -61117, -37951,
-7953, -47443, -75669, -92631, -98329, -92763, -75933, -47839, -9329, -56179, -89717, -109943, -116857,
-110459, -90749, -57727, -10705, -64915, -103765, -127255, -135385, -128155, -105565, -67615
]
}
]
}
]
},
{
"name": "MatMulNBits; K=32, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 32, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=32, N=8, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255
],
"dims": [8, 32],
"type": "float32"
},
{
"dims": [8, 2, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128
]
},
{
"dims": [16],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
},
{
"dims": [8],
"type": "uint8",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
1935, 6941, 12491, 18585, 25223, 32405, 40131, 48401, 4655, 17661, 31211, 45305, 59943, 75125, 90851,
107121, 7375, 28381, 49931, 72025, 94663, 117845, 141571, 165841, 10095, 39101, 68651, 98745, 129383,
160565, 192291, 224561, 12815, 49821, 87371, 125465, 164103, 203285, 243011, 283281, 15535, 60541, 106091,
152185, 198823, 246005, 293731, 342001, 18255, 71261, 124811, 178905, 233543, 288725, 344451, 400721,
20975, 81981, 143531, 205625, 268263, 331445, 395171, 459441
]
}
]
}
]
},
{
"name": "MatMulNBits; K=48, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 48, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336,
337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
379, 380, 381, 382, 383
],
"dims": [8, 48],
"type": "float32"
},
{
"dims": [8, 3, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192
]
},
{
"dims": [24],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-7569, -13416, -24375, -14292, -20445, 5568, 4221, 46164, -17697, -39528, -73383, -45588, -66861, 10560,
1869, 128916, -27825, -65640, -122391, -76884, -113277, 15552, -483, 211668, -37953, -91752, -171399,
-108180, -159693, 20544, -2835, 294420, -48081, -117864, -220407, -139476, -206109, 25536, -5187, 377172,
-58209, -143976, -269415, -170772, -252525, 30528, -7539, 459924, -68337, -170088, -318423, -202068,
-298941, 35520, -9891, 542676, -78465, -196200, -367431, -233364, -345357, 40512, -12243, 625428
]
}
]
}
]
},
{
"name": "MatMulNBits; K=48, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 48, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=48, N=8, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336,
337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
379, 380, 381, 382, 383
],
"dims": [8, 48],
"type": "float32"
},
{
"dims": [8, 3, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192
]
},
{
"dims": [24],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
},
{
"dims": [16],
"type": "uint8",
"data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-1353, -5984, -24751, -31500, -63509, -72376, -117627, -128612, -6105, -20576, -74527, -94284, -190565,
-215608, -354219, -384548, -10857, -35168, -124303, -157068, -317621, -358840, -590811, -640484, -15609,
-49760, -174079, -219852, -444677, -502072, -827403, -896420, -20361, -64352, -223855, -282636, -571733,
-645304, -1063995, -1152356, -25113, -78944, -273631, -345420, -698789, -788536, -1300587, -1408292,
-29865, -93536, -323407, -408204, -825845, -931768, -1537179, -1664228, -34617, -108128, -373183, -470988,
-952901, -1075000, -1773771, -1920164
]
}
]
}
]
},
{
"name": "MatMulNBits; K=64, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 64, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336,
337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399,
400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420,
421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441,
442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462,
463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483,
484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504,
505, 506, 507, 508, 509, 510, 511
],
"dims": [8, 64],
"type": "float32"
},
{
"dims": [8, 4, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255, 256
]
},
{
"dims": [32],
"type": "float32",
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31
]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-13572, -28812, -27668, -10140, 23772, 74068, 140748, 192564, -33796, -91532, -100116, -59548, 30172,
169044, 357068, 531252, -54020, -154252, -172564, -108956, 36572, 264020, 573388, 869940, -74244, -216972,
-245012, -158364, 42972, 358996, 789708, 1208628, -94468, -279692, -317460, -207772, 49372, 453972,
1006028, 1547316, -114692, -342412, -389908, -257180, 55772, 548948, 1222348, 1886004, -134916, -405132,
-462356, -306588, 62172, 643924, 1438668, 2224692, -155140, -467852, -534804, -355996, 68572, 738900,
1654988, 2563380
]
}
]
}
]
},
{
"name": "MatMulNBits; K=64, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 64, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=64, N=8, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336,
337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399,
400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420,
421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441,
442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462,
463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483,
484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504,
505, 506, 507, 508, 509, 510, 511
],
"dims": [8, 64],
"type": "float32"
},
{
"dims": [8, 4, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255, 256
]
},
{
"dims": [32],
"type": "float32",
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31
]
},
{
"dims": [16],
"type": "uint8",
"data": [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-26004, -63644, -96932, -125868, -150452, -170684, -186564, -229340, -60564, -157084, -249252, -337068,
-420532, -499644, -574404, -707804, -95124, -250524, -401572, -548268, -690612, -828604, -962244,
-1186268, -129684, -343964, -553892, -759468, -960692, -1157564, -1350084, -1664732, -164244, -437404,
-706212, -970668, -1230772, -1486524, -1737924, -2143196, -198804, -530844, -858532, -1181868, -1500852,
-1815484, -2125764, -2621660, -233364, -624284, -1010852, -1393068, -1770932, -2144444, -2513604,
-3100124, -267924, -717724, -1163172, -1604268, -2041012, -2473404, -2901444, -3578588
]
}
]
}
]
},
{
"name": "MatMulNBits; K=80, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 80, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=80, N=8, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294,
295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336,
337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357,
358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378,
379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399,
400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420,
421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441,
442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462,
463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483,
484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504,
505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525,
526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546,
547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567,
568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588,
589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609,
610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630,
631, 632, 633, 634, 635, 636, 637, 638, 639
],
"dims": [8, 80],
"type": "float32"
},
{
"dims": [8, 5, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148,
149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232,
233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274,
275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295,
296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316,
317, 318, 319, 320
]
},
{
"dims": [40],
"type": "float32",
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39
]
},
{
"dims": [24],
"type": "uint8",
"data": [
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260,
261, 262, 263
]
}
],
"outputs": [
{
"dims": [8, 8],
"type": "float32",
"data": [
-19988, -63429, -155448, -216179, -358428, 740351, 259888, 172481, -56788, -186869, -451128, -632899,
-1053788, 1574031, 1165488, 546481, -93588, -310309, -746808, -1049619, -1749148, 2407711, 2071088,
920481, -130388, -433749, -1042488, -1466339, -2444508, 3241391, 2976688, 1294481, -167188, -557189,
-1338168, -1883059, -3139868, 4075071, 3882288, 1668481, -203988, -680629, -1633848, -2299779, -3835228,
4908751, 4787888, 2042481, -240788, -804069, -1929528, -2716499, -4530588, 5742431, 5693488, 2416481,
-277588, -927509, -2225208, -3133219, -5225948, 6576111, 6599088, 2790481
]
}
]
}
]
},
{
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4",
"operator": "MatMulNBits",
@ -188,7 +829,7 @@
{
"dims": [16],
"type": "uint8",
"data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
"data": [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
}
],
"outputs": [
@ -196,24 +837,25 @@
"dims": [16, 16],
"type": "float32",
"data": [
0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200,
1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672,
2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144,
4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0,
6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720,
0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272,
195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176,
123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112,
218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384,
124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560,
103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360,
81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320,
57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976,
35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864,
15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800,
479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256,
422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200,
360008, 226848, 451256, 292432, 550440
0, 608, 208, 1296, -288, 1280, -1488, 560, -3392, -864, -6000, -2992, -9312, -5824, -13328, -9360, 0,
1824, 336, 3792, -1568, 3520, -5712, 1008, -12096, -3744, -20720, -10736, -31584, -19968, -44688, -31440,
0, 3040, 464, 6288, -2848, 5760, -9936, 1456, -20800, -6624, -35440, -18480, -53856, -34112, -76048,
-53520, 0, 4256, 592, 8784, -4128, 8000, -14160, 1904, -29504, -9504, -50160, -26224, -76128, -48256,
-107408, -75600, 0, 5472, 720, 11280, -5408, 10240, -18384, 2352, -38208, -12384, -64880, -33968, -98400,
-62400, -138768, -97680, 0, 6688, 848, 13776, -6688, 12480, -22608, 2800, -46912, -15264, -79600, -41712,
-120672, -76544, -170128, -119760, 0, 7904, 976, 16272, -7968, 14720, -26832, 3248, -55616, -18144,
-94320, -49456, -142944, -90688, -201488, -141840, 0, 9120, 1104, 18768, -9248, 16960, -31056, 3696,
-64320, -21024, -109040, -57200, -165216, -104832, -232848, -163920, 0, 10336, 1232, 21264, -10528, 19200,
-35280, 4144, -73024, -23904, -123760, -64944, -187488, -118976, -264208, -186000, 0, 11552, 1360, 23760,
-11808, 21440, -39504, 4592, -81728, -26784, -138480, -72688, -209760, -133120, -295568, -208080, 0,
12768, 1488, 26256, -13088, 23680, -43728, 5040, -90432, -29664, -153200, -80432, -232032, -147264,
-326928, -230160, 0, 13984, 1616, 28752, -14368, 25920, -47952, 5488, -99136, -32544, -167920, -88176,
-254304, -161408, -358288, -252240, 0, 15200, 1744, 31248, -15648, 28160, -52176, 5936, -107840, -35424,
-182640, -95920, -276576, -175552, -389648, -274320, 0, 16416, 1872, 33744, -16928, 30400, -56400, 6384,
-116544, -38304, -197360, -103664, -298848, -189696, -421008, -296400, 0, 17632, 2000, 36240, -18208,
32640, -60624, 6832, -125248, -41184, -212080, -111408, -321120, -203840, -452368, -318480, 0, 18848,
2128, 38736, -19488, 34880, -64848, 7280, -133952, -44064, -226800, -119152, -343392, -217984, -483728,
-340560
]
}
]
@ -1580,5 +2222,265 @@
]
}
]
},
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1]; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127
],
"dims": [1, 8, 16],
"type": "float32"
},
{
"dims": [8, 1, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
]
},
{
"dims": [1, 8],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
}
],
"outputs": [
{
"dims": [1, 8, 8],
"type": "float32",
"data": [
0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
-1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
-11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
-16405, -48288, -16247
]
}
]
}
]
},
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4, batchDim = [1, 2]",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric, batchDim = [1, 2]",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255
],
"dims": [1, 2, 8, 16],
"type": "float32"
},
{
"dims": [8, 1, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
]
},
{
"dims": [1, 8],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
}
],
"outputs": [
{
"dims": [1, 2, 8, 8],
"type": "float32",
"data": [
0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
-1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
-11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
-16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312,
-16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0,
-7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928,
-26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080,
-31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271
]
}
]
}
]
},
{
"name": "MatMulNBits; output shape = 8 X 16; K=16, N=16, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 16, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127
],
"dims": [8, 16],
"type": "float32"
},
{
"dims": [16, 1, 8],
"type": "uint8",
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127
]
},
{
"dims": [16],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
},
{
"dims": [16],
"type": "uint8",
"data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128]
}
],
"outputs": [
{
"dims": [8, 16],
"type": "float32",
"data": [
0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200,
1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672,
2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144,
4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0,
6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720,
0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272,
195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176,
123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112,
218296, 141904, 266280
]
}
]
}
]
},
{
"name": "MatMulNBits; output shape = 16 X 8; K=16, N=8, block_size=16, bits=4",
"operator": "MatMulNBits",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [
{ "name": "K", "data": 16, "type": "int" },
{ "name": "N", "data": 8, "type": "int" },
{ "name": "block_size", "data": 16, "type": "int" },
{ "name": "bits", "data": 4, "type": "int" }
],
"cases": [
{
"name": "MatMulNBits; K=16, N=8, block_size=16, bits=4; symmetric",
"inputs": [
{
"data": [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126,
127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252,
253, 254, 255
],
"dims": [16, 16],
"type": "float32"
},
{
"dims": [8, 1, 8],
"type": "uint8",
"data": [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64
]
},
{
"dims": [8],
"type": "float32",
"data": [0, 1, 2, 3, 4, 5, 6, 7]
}
],
"outputs": [
{
"dims": [16, 8],
"type": "float32",
"data": [
0, -385, -1120, -963, -1984, -1285, -2592, -1351, 0, -1073, -3808, -2643, -6848, -3445, -9120, -3479, 0,
-1761, -6496, -4323, -11712, -5605, -15648, -5607, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735,
0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, 0, -3825, -14560, -9363, -26304, -12085, -35232,
-11991, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, 0, -5201, -19936, -12723, -36032,
-16405, -48288, -16247, 0, -5889, -22624, -14403, -40896, -18565, -54816, -18375, 0, -6577, -25312,
-16083, -45760, -20725, -61344, -20503, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, 0,
-7953, -30688, -19443, -55488, -25045, -74400, -24759, 0, -8641, -33376, -21123, -60352, -27205, -80928,
-26887, 0, -9329, -36064, -22803, -65216, -29365, -87456, -29015, 0, -10017, -38752, -24483, -70080,
-31525, -93984, -31143, 0, -10705, -41440, -26163, -74944, -33685, -100512, -33271
]
}
]
}
]
}
]