[js/web] Add support for int4/uint4 tensor (#21720)

### Description
Add support for int4/uint4 tensor.
This commit is contained in:
Yulong Wang 2024-08-15 21:32:10 -07:00 committed by GitHub
parent d4d0bea1fb
commit ef2ccc477b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 206 additions and 84 deletions

View file

@ -28,6 +28,8 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
['bool', Uint8Array],
['float64', Float64Array],
['uint32', Uint32Array],
['int4', Uint8Array],
['uint4', Uint8Array],
]);
// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.

View file

@ -180,14 +180,19 @@ export class Tensor implements TensorInterface {
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
}
if (Array.isArray(arg1)) {
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
if ((arg0 === 'float16' && typedArrayConstructor === Uint16Array) || arg0 === 'uint4' || arg0 === 'int4') {
// - 'float16':
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
//
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
//
// - 'uint4' and 'int4':
// Uint8Array.from(arg1) will generate wrong data for 'uint4' and 'int4' tensor.
//
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
throw new TypeError(
'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.',
`Creating a ${arg0} tensor from number array is not supported. Please use ${typedArrayConstructor.name} as data.`,
);
} else if (arg0 === 'uint64' || arg0 === 'int64') {
// use 'as any' here because:
@ -266,7 +271,11 @@ export class Tensor implements TensorInterface {
const size = calculateSize(dims);
// if data is on CPU, check whether data length matches tensor size
if (this.cpuData && size !== this.cpuData.length) {
throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
if ((type === 'uint4' || type === 'int4') && Math.ceil(size / 2) === this.cpuData.length) {
// for (u)int4, the data length is half of the tensor size. So we check this special case when size is odd.
} else {
throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
}
}
this.type = type;

View file

@ -81,6 +81,8 @@ export declare namespace Tensor {
// complex64: never;
// complex128: never;
// bfloat16: never;
uint4: Uint8Array;
int4: Int8Array;
}
interface ElementTypeMap {
@ -100,6 +102,8 @@ export declare namespace Tensor {
// complex64: never;
// complex128: never;
// bfloat16: never;
uint4: number;
int4: number;
}
type DataType = DataTypeMap[Type];

View file

@ -4,7 +4,7 @@
import { Env } from 'onnxruntime-common';
import type { OrtWasmModule } from '../wasm-types';
import { DataType, getTensorElementSize } from '../wasm-common';
import { DataType, calculateTensorSizeInBytes } from '../wasm-common';
import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
@ -122,11 +122,10 @@ class ComputeContextImpl implements ComputeContext {
const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
const elementSize = getTensorElementSize(dataType);
if (!elementSize) {
const bufferSize = calculateTensorSizeInBytes(dataType, dims);
if (!bufferSize) {
throw new Error(`Unsupported data type: ${dataType}`);
}
const bufferSize = elementSize * ShapeUtil.size(dims);
const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
};
@ -245,9 +244,7 @@ export const init = async (
LOG_DEBUG(
'verbose',
() =>
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset
}`,
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType, getTensorElementSize } from '../../../wasm-common';
import { calculateTensorSizeInBytes, DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
@ -77,8 +77,7 @@ export const createMatMulNBitsProgramInfo = (
const outputNumber = getMaxComponents(dimAOuter);
const aComponents = getMaxComponents(attributes.k);
const bComponents = getMaxComponents(blobSizeInWords);
const elementSize = getTensorElementSize(dataType)!;
const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize;
const workgroupOutputSize = calculateTensorSizeInBytes(dataType, dimAOuter * nBlocksPerCol)!;
const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize);
const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0;
const components =

View file

@ -32,6 +32,10 @@ export const enum DataType {
complex64 = 14,
complex128 = 15,
bfloat16 = 16,
// 4-bit data-types
uint4 = 21,
int4 = 22,
}
/**
@ -65,6 +69,10 @@ export const tensorDataTypeStringToEnum = (type: string): DataType => {
return DataType.int64;
case 'uint64':
return DataType.uint64;
case 'int4':
return DataType.int4;
case 'uint4':
return DataType.uint4;
default:
throw new Error(`unsupported data type: ${type}`);
@ -102,6 +110,10 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type =>
return 'int64';
case DataType.uint64:
return 'uint64';
case DataType.int4:
return 'int4';
case DataType.uint4:
return 'uint4';
default:
throw new Error(`unsupported data type: ${typeProto}`);
@ -109,11 +121,42 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type =>
};
/**
* get tensor element size in bytes by the given data type
* get tensor size in bytes by the given data type and dimensions
* @returns size in integer or undefined if the data type is not supported
*/
export const getTensorElementSize = (dateType: number): number | undefined =>
[undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType];
export const calculateTensorSizeInBytes = (
dateType: number,
dimsOrSize: readonly number[] | number,
): number | undefined => {
const elementSize = [
-1, // undefined = 0
4, // float = 1
1, // uint8 = 2
1, // int8 = 3
2, // uint16 = 4
2, // int16 = 5
4, // int32 = 6
8, // int64 = 7
-1, // string = 8
1, // bool = 9
2, // float16 = 10
8, // double = 11
4, // uint32 = 12
8, // uint64 = 13
-1, // complex64 = 14
-1, // complex128 = 15
-1, // bfloat16 = 16
-1, // FLOAT8E4M3FN = 17
-1, // FLOAT8E4M3FNUZ = 18
-1, // FLOAT8E5M2 = 19
-1, // FLOAT8E5M2FNUZ = 20
0.5, // uint4 = 21
0.5, // int4 = 22
][dateType];
const size = typeof dimsOrSize === 'number' ? dimsOrSize : dimsOrSize.reduce((a, b) => a * b, 1);
return elementSize > 0 ? Math.ceil(size * elementSize) : undefined;
};
/**
* get typed array constructor by the given tensor type

View file

@ -17,8 +17,8 @@ import {
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
calculateTensorSizeInBytes,
dataLocationStringToEnum,
getTensorElementSize,
isGpuBufferSupportedType,
logLevelStringToEnum,
tensorDataTypeEnumToString,
@ -360,9 +360,7 @@ export const createSession = async (
}
if (enableGraphCapture && location !== 'gpu-buffer') {
throw new Error(
`Not supported preferred output location: ${
location
}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
`Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
);
}
outputPreferredLocations.push(location);
@ -474,8 +472,7 @@ export const prepareInputOutputTensor = (
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
const registerBuffer = wasm.jsepRegisterBuffer;
if (!registerBuffer) {
@ -611,9 +608,7 @@ export const run = async (
if (inputNamesUTF8Encoded.length !== inputCount) {
throw new Error(
`input count from feeds (${
inputCount
}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
`input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
);
}
@ -752,8 +747,8 @@ export const run = async (
throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
}
const gpuBuffer = getBuffer(dataOffset);
const elementSize = getTensorElementSize(dataType);
if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
const bufferSize = calculateTensorSizeInBytes(dataType, size);
if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}
@ -765,7 +760,7 @@ export const run = async (
dims,
{
gpuBuffer,
download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type),
download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
dispose: () => {
wasm._OrtReleaseTensor(tensor);
},

View file

@ -0,0 +1,72 @@
[
{
"name": "DequantizeLinear int4",
"opset": { "domain": "", "version": 21 },
"operator": "DequantizeLinear",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "T[2,3]",
"inputs": [
{
"data": [0, 1, 7, -4, -8],
"dims": [5],
"type": "int4"
},
{
"data": [2],
"dims": [],
"type": "float32"
},
{
"data": [1],
"dims": [1],
"type": "int4"
}
],
"outputs": [
{
"data": [-2, 0, 12, -10, -18],
"dims": [5],
"type": "float32"
}
]
}
]
},
{
"name": "DequantizeLinear uint4",
"opset": { "domain": "", "version": 21 },
"operator": "DequantizeLinear",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "T[2,3]",
"inputs": [
{
"data": [0, 1, 7, 10, 15],
"dims": [5],
"type": "uint4"
},
{
"data": [2],
"dims": [],
"type": "float32"
},
{
"data": [1],
"dims": [1],
"type": "uint4"
}
],
"outputs": [
{
"data": [-2, 0, 12, 18, 28],
"dims": [5],
"type": "float32"
}
]
}
]
}
]

View file

@ -189,7 +189,9 @@
"uint32",
"uint64",
"bool",
"string"
"string",
"int4",
"uint4"
]
},
"data": {
@ -226,7 +228,9 @@
"uint32",
"uint64",
"bool",
"string"
"string",
"int4",
"uint4"
]
},
"data": {
@ -261,7 +265,9 @@
"uint32",
"uint64",
"bool",
"string"
"string",
"int4",
"uint4"
]
},
"data": {
@ -298,7 +304,9 @@
"uint32",
"uint64",
"bool",
"string"
"string",
"int4",
"uint4"
]
},
"data": {

View file

@ -16,7 +16,11 @@ import { onnx } from '../lib/onnxjs/ort-schema/protobuf/onnx';
import { Tensor } from '../lib/onnxjs/tensor';
import { ProtoUtil } from '../lib/onnxjs/util';
import { createView } from '../lib/wasm/jsep/tensor-view';
import { getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum } from '../lib/wasm/wasm-common';
import {
calculateTensorSizeInBytes,
isGpuBufferSupportedType,
tensorDataTypeStringToEnum,
} from '../lib/wasm/wasm-common';
import { base64toBuffer, createMockGraph, readFile } from './test-shared';
import { Test } from './test-types';
@ -372,9 +376,7 @@ export class TensorResultValidator {
if (!match) {
Logger.error(
'TestRunner',
`Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${
actual[i].data
}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`,
`Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${actual[i].data}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`,
);
}
expect(match, 'tensor data should match').to.be.true;
@ -462,6 +464,8 @@ export class TensorResultValidator {
case 'uint32':
case 'int64':
case 'bool':
case 'int4':
case 'uint4':
return TensorResultValidator.integerEqual(
actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
@ -586,8 +590,7 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
}
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!;
const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;
const device = ort.env.webgpu.device as GPUDevice;
const gpuBuffer = device.createBuffer({
@ -852,22 +855,14 @@ export class ProtoOpTestContext {
for (let i = 0; i < inputCount; i++) {
if (inputsOmitted[i] !== !testCase.inputs![i].data) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${
caseIndex
} should be both available or both omitted.`,
`Test cases for test: ${test.name} [${test.operator}] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
);
}
}
for (let i = 0; i < outputCount; i++) {
if (outputsOmitted[i] !== !testCase.outputs![i].data) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have consistent outputs data availability. Data of output[${
i
}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
`Test cases for test: ${test.name} [${test.operator}] must have consistent outputs data availability. Data of output[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
);
}
}
@ -898,9 +893,7 @@ export class ProtoOpTestContext {
// check if all test cases have data
if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
`Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
);
}
@ -919,18 +912,14 @@ export class ProtoOpTestContext {
)
) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have the same rank for each inputs in different test cases`,
`Test cases for test: ${test.name} [${test.operator}] must have the same rank for each inputs in different test cases`,
);
}
} else if (test.inputShapeDefinitions === 'static') {
// check if all test cases have data
if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
`Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
);
}
@ -946,9 +935,7 @@ export class ProtoOpTestContext {
)
) {
throw new Error(
`Test cases for test: ${test.name} [${
test.operator
}] must have the same shape for each inputs in different test cases`,
`Test cases for test: ${test.name} [${test.operator}] must have the same shape for each inputs in different test cases`,
);
}
} else {
@ -1033,18 +1020,33 @@ async function runProtoOpTestcase(
): Promise<void> {
const feeds: Record<string, ort.Tensor> = {};
const fetches: Record<string, Pick<ort.Tensor, 'dims' | 'type'>> = {};
const createTensor = (type: ort.Tensor.Type, data: number[], dims: readonly number[]): ort.Tensor => {
let buffer: number[] | BigUint64Array | BigInt64Array | Uint16Array | Uint8Array = data;
if (type === 'uint64') {
buffer = BigUint64Array.from(data.map(BigInt));
} else if (type === 'int64') {
buffer = BigInt64Array.from(data.map(BigInt));
} else if (type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(data);
buffer = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
} else if (type === 'uint4' || type === 'int4') {
buffer = new Uint8Array(calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!);
// encode (u)int4 data into Uint8Array
for (let j = 0; j < data.length; j++) {
/* eslint-disable no-bitwise */
const byteIndex = j >> 1;
const bitOffset = (j & 1) << 2;
buffer[byteIndex] |= data[j] << bitOffset;
/* eslint-enable no-bitwise */
}
}
return new ort.Tensor(type, buffer, dims);
};
testCase.inputs.forEach((input, i) => {
if (input.data) {
let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = input.data;
if (input.type === 'uint64') {
data = BigUint64Array.from(input.data.map(BigInt));
} else if (input.type === 'int64') {
data = BigInt64Array.from(input.data.map(BigInt));
} else if (input.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(input.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
feeds[`input_${i}`] = new ort.Tensor(input.type, data, input.dims);
feeds[`input_${i}`] = createTensor(input.type, input.data, input.dims);
}
});
@ -1052,16 +1054,7 @@ async function runProtoOpTestcase(
const expectedOutputNames: string[] = [];
testCase.outputs.forEach((output, i) => {
if (output.data) {
let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = output.data;
if (output.type === 'uint64') {
data = BigUint64Array.from(output.data.map(BigInt));
} else if (output.type === 'int64') {
data = BigInt64Array.from(output.data.map(BigInt));
} else if (output.type === 'float16') {
const dataArr = Float16ArrayPolyfill.from(output.data);
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
}
outputs.push(new ort.Tensor(output.type, data, output.dims));
outputs.push(createTensor(output.type, output.data, output.dims));
expectedOutputNames.push(`output_${i}`);
fetches[`output_${i}`] = { dims: output.dims, type: output.type };
}