mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[js/web] Add support for int4/uint4 tensor (#21720)
### Description Add support for int4/uint4 tensor.
This commit is contained in:
parent
d4d0bea1fb
commit
ef2ccc477b
10 changed files with 206 additions and 84 deletions
|
|
@ -28,6 +28,8 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
|
|||
['bool', Uint8Array],
|
||||
['float64', Float64Array],
|
||||
['uint32', Uint32Array],
|
||||
['int4', Uint8Array],
|
||||
['uint4', Uint8Array],
|
||||
]);
|
||||
|
||||
// a runtime map that maps type string to TypedArray constructor. Should match Tensor.DataTypeMap.
|
||||
|
|
|
|||
|
|
@ -180,14 +180,19 @@ export class Tensor implements TensorInterface {
|
|||
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
|
||||
}
|
||||
if (Array.isArray(arg1)) {
|
||||
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
|
||||
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
|
||||
if ((arg0 === 'float16' && typedArrayConstructor === Uint16Array) || arg0 === 'uint4' || arg0 === 'int4') {
|
||||
// - 'float16':
|
||||
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
|
||||
//
|
||||
// Throw error here because when user try to use number array as data,
|
||||
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
|
||||
// Uint16Array.from(arg1) which generates wrong data.
|
||||
//
|
||||
// - 'uint4' and 'int4':
|
||||
// Uint8Array.from(arg1) will generate wrong data for 'uint4' and 'int4' tensor.
|
||||
//
|
||||
// Throw error here because when user try to use number array as data,
|
||||
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
|
||||
// Uint16Array.from(arg1) which generates wrong data.
|
||||
throw new TypeError(
|
||||
'Creating a float16 tensor from number array is not supported. Please use Uint16Array as data.',
|
||||
`Creating a ${arg0} tensor from number array is not supported. Please use ${typedArrayConstructor.name} as data.`,
|
||||
);
|
||||
} else if (arg0 === 'uint64' || arg0 === 'int64') {
|
||||
// use 'as any' here because:
|
||||
|
|
@ -266,7 +271,11 @@ export class Tensor implements TensorInterface {
|
|||
const size = calculateSize(dims);
|
||||
// if data is on CPU, check whether data length matches tensor size
|
||||
if (this.cpuData && size !== this.cpuData.length) {
|
||||
throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
|
||||
if ((type === 'uint4' || type === 'int4') && Math.ceil(size / 2) === this.cpuData.length) {
|
||||
// for (u)int4, two elements are packed into each byte, so the expected data length is ceil(size / 2) rather than size (this also covers odd sizes).
|
||||
} else {
|
||||
throw new Error(`Tensor's size(${size}) does not match data length(${this.cpuData.length}).`);
|
||||
}
|
||||
}
|
||||
|
||||
this.type = type;
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ export declare namespace Tensor {
|
|||
// complex64: never;
|
||||
// complex128: never;
|
||||
// bfloat16: never;
|
||||
uint4: Uint8Array;
|
||||
int4: Int8Array;
|
||||
}
|
||||
|
||||
interface ElementTypeMap {
|
||||
|
|
@ -100,6 +102,8 @@ export declare namespace Tensor {
|
|||
// complex64: never;
|
||||
// complex128: never;
|
||||
// bfloat16: never;
|
||||
uint4: number;
|
||||
int4: number;
|
||||
}
|
||||
|
||||
type DataType = DataTypeMap[Type];
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import { Env } from 'onnxruntime-common';
|
||||
|
||||
import type { OrtWasmModule } from '../wasm-types';
|
||||
import { DataType, getTensorElementSize } from '../wasm-common';
|
||||
import { DataType, calculateTensorSizeInBytes } from '../wasm-common';
|
||||
|
||||
import { WebGpuBackend } from './backend-webgpu';
|
||||
import { LOG_DEBUG } from './log';
|
||||
|
|
@ -122,11 +122,10 @@ class ComputeContextImpl implements ComputeContext {
|
|||
const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
|
||||
new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
|
||||
// Creates a temporary output tensor view for the given data type and dims. A GPU buffer is only allocated when the
// tensor occupies at least one byte; an empty tensor gets GPU data ID 0.
const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
  const bufferSize = calculateTensorSizeInBytes(dataType, dims);
  // `calculateTensorSizeInBytes` returns `undefined` for an unsupported data type but `0` for an empty tensor, so
  // compare against `undefined` explicitly: a truthiness check would reject valid zero-sized tensors.
  if (bufferSize === undefined) {
    throw new Error(`Unsupported data type: ${dataType}`);
  }
  const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
  return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
};
|
||||
|
|
@ -245,9 +244,7 @@ export const init = async (
|
|||
LOG_DEBUG(
|
||||
'verbose',
|
||||
() =>
|
||||
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
|
||||
contextDataOffset
|
||||
}`,
|
||||
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
|
||||
);
|
||||
const context = new ComputeContextImpl(module, backend, contextDataOffset);
|
||||
return backend.computeKernel(kernel, context, errors);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { DataType, getTensorElementSize } from '../../../wasm-common';
|
||||
import { calculateTensorSizeInBytes, DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
|
|
@ -77,8 +77,7 @@ export const createMatMulNBitsProgramInfo = (
|
|||
const outputNumber = getMaxComponents(dimAOuter);
|
||||
const aComponents = getMaxComponents(attributes.k);
|
||||
const bComponents = getMaxComponents(blobSizeInWords);
|
||||
const elementSize = getTensorElementSize(dataType)!;
|
||||
const workgroupOutputSize = dimAOuter * nBlocksPerCol * elementSize;
|
||||
const workgroupOutputSize = calculateTensorSizeInBytes(dataType, dimAOuter * nBlocksPerCol)!;
|
||||
const maxNumberOfComponents = Math.floor(maxComputeWorkgroupStorageSize / workgroupOutputSize);
|
||||
const useBlockwiseMatMulNBits = nBlocksPerCol <= maxComputeWorkgroupSizes[0] && maxNumberOfComponents > 0;
|
||||
const components =
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ export const enum DataType {
|
|||
complex64 = 14,
|
||||
complex128 = 15,
|
||||
bfloat16 = 16,
|
||||
|
||||
// 4-bit data-types
|
||||
uint4 = 21,
|
||||
int4 = 22,
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -65,6 +69,10 @@ export const tensorDataTypeStringToEnum = (type: string): DataType => {
|
|||
return DataType.int64;
|
||||
case 'uint64':
|
||||
return DataType.uint64;
|
||||
case 'int4':
|
||||
return DataType.int4;
|
||||
case 'uint4':
|
||||
return DataType.uint4;
|
||||
|
||||
default:
|
||||
throw new Error(`unsupported data type: ${type}`);
|
||||
|
|
@ -102,6 +110,10 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type =>
|
|||
return 'int64';
|
||||
case DataType.uint64:
|
||||
return 'uint64';
|
||||
case DataType.int4:
|
||||
return 'int4';
|
||||
case DataType.uint4:
|
||||
return 'uint4';
|
||||
|
||||
default:
|
||||
throw new Error(`unsupported data type: ${typeProto}`);
|
||||
|
|
@ -109,11 +121,42 @@ export const tensorDataTypeEnumToString = (typeProto: DataType): Tensor.Type =>
|
|||
};
|
||||
|
||||
/**
|
||||
* get tensor element size in bytes by the given data type
|
||||
* get tensor size in bytes by the given data type and dimensions
|
||||
* @returns size in integer or undefined if the data type is not supported
|
||||
*/
|
||||
export const getTensorElementSize = (dateType: number): number | undefined =>
|
||||
[undefined, 4, 1, 1, 2, 2, 4, 8, undefined, 1, 2, 8, 4, 8, undefined, undefined, undefined][dateType];
|
||||
/**
 * Element size in bytes for each tensor data type, indexed by the numeric data type value.
 * `undefined` marks data types that are not supported here. The 4-bit types take half a byte per element.
 */
const TENSOR_ELEMENT_SIZE_IN_BYTES: ReadonlyArray<number | undefined> = [
  undefined, // undefined = 0
  4, // float = 1
  1, // uint8 = 2
  1, // int8 = 3
  2, // uint16 = 4
  2, // int16 = 5
  4, // int32 = 6
  8, // int64 = 7
  undefined, // string = 8
  1, // bool = 9
  2, // float16 = 10
  8, // double = 11
  4, // uint32 = 12
  8, // uint64 = 13
  undefined, // complex64 = 14
  undefined, // complex128 = 15
  undefined, // bfloat16 = 16
  undefined, // FLOAT8E4M3FN = 17
  undefined, // FLOAT8E4M3FNUZ = 18
  undefined, // FLOAT8E5M2 = 19
  undefined, // FLOAT8E5M2FNUZ = 20
  0.5, // uint4 = 21
  0.5, // int4 = 22
];

/**
 * Calculate the number of bytes needed to store a tensor of the given data type and dimensions.
 *
 * For the 4-bit types two elements share one byte, so the result is rounded up to a whole number of bytes.
 *
 * @param dateType - the numeric tensor data type (a `DataType` enum value).
 * @param dimsOrSize - the tensor dimensions, or the total number of elements.
 * @returns the size in bytes as an integer, or `undefined` if the data type is not supported.
 */
export const calculateTensorSizeInBytes = (
  dateType: number,
  dimsOrSize: readonly number[] | number,
): number | undefined => {
  // Indexing outside the table (unknown or non-integer type values) also yields `undefined`.
  const elementSize = TENSOR_ELEMENT_SIZE_IN_BYTES[dateType];
  if (elementSize === undefined) {
    return undefined;
  }

  // An empty dims array describes a scalar, which has exactly one element.
  const size = typeof dimsOrSize === 'number' ? dimsOrSize : dimsOrSize.reduce((a, b) => a * b, 1);
  return Math.ceil(size * elementSize);
};
|
||||
|
||||
/**
|
||||
* get typed array constructor by the given tensor type
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ import {
|
|||
import { setRunOptions } from './run-options';
|
||||
import { setSessionOptions } from './session-options';
|
||||
import {
|
||||
calculateTensorSizeInBytes,
|
||||
dataLocationStringToEnum,
|
||||
getTensorElementSize,
|
||||
isGpuBufferSupportedType,
|
||||
logLevelStringToEnum,
|
||||
tensorDataTypeEnumToString,
|
||||
|
|
@ -360,9 +360,7 @@ export const createSession = async (
|
|||
}
|
||||
if (enableGraphCapture && location !== 'gpu-buffer') {
|
||||
throw new Error(
|
||||
`Not supported preferred output location: ${
|
||||
location
|
||||
}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
|
||||
`Not supported preferred output location: ${location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`,
|
||||
);
|
||||
}
|
||||
outputPreferredLocations.push(location);
|
||||
|
|
@ -474,8 +472,7 @@ export const prepareInputOutputTensor = (
|
|||
|
||||
if (location === 'gpu-buffer') {
|
||||
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
|
||||
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
|
||||
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
|
||||
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
|
||||
|
||||
const registerBuffer = wasm.jsepRegisterBuffer;
|
||||
if (!registerBuffer) {
|
||||
|
|
@ -611,9 +608,7 @@ export const run = async (
|
|||
|
||||
if (inputNamesUTF8Encoded.length !== inputCount) {
|
||||
throw new Error(
|
||||
`input count from feeds (${
|
||||
inputCount
|
||||
}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
|
||||
`input count from feeds (${inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -752,8 +747,8 @@ export const run = async (
|
|||
throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
|
||||
}
|
||||
const gpuBuffer = getBuffer(dataOffset);
|
||||
const elementSize = getTensorElementSize(dataType);
|
||||
if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
|
||||
const bufferSize = calculateTensorSizeInBytes(dataType, size);
|
||||
if (bufferSize === undefined || !isGpuBufferSupportedType(type)) {
|
||||
throw new Error(`Unsupported data type: ${type}`);
|
||||
}
|
||||
|
||||
|
|
@ -765,7 +760,7 @@ export const run = async (
|
|||
dims,
|
||||
{
|
||||
gpuBuffer,
|
||||
download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type),
|
||||
download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type),
|
||||
dispose: () => {
|
||||
wasm._OrtReleaseTensor(tensor);
|
||||
},
|
||||
|
|
|
|||
72
js/web/test/data/ops/dequantize-linear_int4.jsonc
Normal file
72
js/web/test/data/ops/dequantize-linear_int4.jsonc
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
[
|
||||
{
|
||||
"name": "DequantizeLinear int4",
|
||||
"opset": { "domain": "", "version": 21 },
|
||||
"operator": "DequantizeLinear",
|
||||
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2,3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0, 1, 7, -4, -8],
|
||||
"dims": [5],
|
||||
"type": "int4"
|
||||
},
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int4"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [-2, 0, 12, -10, -18],
|
||||
"dims": [5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "DequantizeLinear uint4",
|
||||
"opset": { "domain": "", "version": 21 },
|
||||
"operator": "DequantizeLinear",
|
||||
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2,3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0, 1, 7, 10, 15],
|
||||
"dims": [5],
|
||||
"type": "uint4"
|
||||
},
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "uint4"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [-2, 0, 12, 18, 28],
|
||||
"dims": [5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -189,7 +189,9 @@
|
|||
"uint32",
|
||||
"uint64",
|
||||
"bool",
|
||||
"string"
|
||||
"string",
|
||||
"int4",
|
||||
"uint4"
|
||||
]
|
||||
},
|
||||
"data": {
|
||||
|
|
@ -226,7 +228,9 @@
|
|||
"uint32",
|
||||
"uint64",
|
||||
"bool",
|
||||
"string"
|
||||
"string",
|
||||
"int4",
|
||||
"uint4"
|
||||
]
|
||||
},
|
||||
"data": {
|
||||
|
|
@ -261,7 +265,9 @@
|
|||
"uint32",
|
||||
"uint64",
|
||||
"bool",
|
||||
"string"
|
||||
"string",
|
||||
"int4",
|
||||
"uint4"
|
||||
]
|
||||
},
|
||||
"data": {
|
||||
|
|
@ -298,7 +304,9 @@
|
|||
"uint32",
|
||||
"uint64",
|
||||
"bool",
|
||||
"string"
|
||||
"string",
|
||||
"int4",
|
||||
"uint4"
|
||||
]
|
||||
},
|
||||
"data": {
|
||||
|
|
|
|||
|
|
@ -16,7 +16,11 @@ import { onnx } from '../lib/onnxjs/ort-schema/protobuf/onnx';
|
|||
import { Tensor } from '../lib/onnxjs/tensor';
|
||||
import { ProtoUtil } from '../lib/onnxjs/util';
|
||||
import { createView } from '../lib/wasm/jsep/tensor-view';
|
||||
import { getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum } from '../lib/wasm/wasm-common';
|
||||
import {
|
||||
calculateTensorSizeInBytes,
|
||||
isGpuBufferSupportedType,
|
||||
tensorDataTypeStringToEnum,
|
||||
} from '../lib/wasm/wasm-common';
|
||||
|
||||
import { base64toBuffer, createMockGraph, readFile } from './test-shared';
|
||||
import { Test } from './test-types';
|
||||
|
|
@ -372,9 +376,7 @@ export class TensorResultValidator {
|
|||
if (!match) {
|
||||
Logger.error(
|
||||
'TestRunner',
|
||||
`Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${
|
||||
actual[i].data
|
||||
}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`,
|
||||
`Tensor mismatch: \nACTUAL: type=${actual[i].type}; dims=[${actual[i].dims}]; data=[${actual[i].data}]\nEXPECT: type=${expected[i].type}; dims=[${expected[i].dims}]; data=[${expected[i].data}]`,
|
||||
);
|
||||
}
|
||||
expect(match, 'tensor data should match').to.be.true;
|
||||
|
|
@ -462,6 +464,8 @@ export class TensorResultValidator {
|
|||
case 'uint32':
|
||||
case 'int64':
|
||||
case 'bool':
|
||||
case 'int4':
|
||||
case 'uint4':
|
||||
return TensorResultValidator.integerEqual(
|
||||
actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
|
||||
expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array,
|
||||
|
|
@ -586,8 +590,7 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]
|
|||
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
|
||||
}
|
||||
|
||||
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!;
|
||||
const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
|
||||
const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;
|
||||
|
||||
const device = ort.env.webgpu.device as GPUDevice;
|
||||
const gpuBuffer = device.createBuffer({
|
||||
|
|
@ -852,22 +855,14 @@ export class ProtoOpTestContext {
|
|||
for (let i = 0; i < inputCount; i++) {
|
||||
if (inputsOmitted[i] !== !testCase.inputs![i].data) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${
|
||||
caseIndex
|
||||
} should be both available or both omitted.`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have consistent inputs data availability. Data of input[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
for (let i = 0; i < outputCount; i++) {
|
||||
if (outputsOmitted[i] !== !testCase.outputs![i].data) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have consistent outputs data availability. Data of output[${
|
||||
i
|
||||
}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have consistent outputs data availability. Data of output[${i}] in testCase #0 and #${caseIndex} should be both available or both omitted.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -898,9 +893,7 @@ export class ProtoOpTestContext {
|
|||
// check if all test cases have data
|
||||
if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -919,18 +912,14 @@ export class ProtoOpTestContext {
|
|||
)
|
||||
) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have the same rank for each inputs in different test cases`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have the same rank for each inputs in different test cases`,
|
||||
);
|
||||
}
|
||||
} else if (test.inputShapeDefinitions === 'static') {
|
||||
// check if all test cases have data
|
||||
if (test.cases.some((testCase) => testCase.inputs!.some((input) => !input.data || !input.dims))) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have data for each inputs when inputShapeDefinitions is 'rankOnly'`,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -946,9 +935,7 @@ export class ProtoOpTestContext {
|
|||
)
|
||||
) {
|
||||
throw new Error(
|
||||
`Test cases for test: ${test.name} [${
|
||||
test.operator
|
||||
}] must have the same shape for each inputs in different test cases`,
|
||||
`Test cases for test: ${test.name} [${test.operator}] must have the same shape for each inputs in different test cases`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
|
|
@ -1033,18 +1020,33 @@ async function runProtoOpTestcase(
|
|||
): Promise<void> {
|
||||
const feeds: Record<string, ort.Tensor> = {};
|
||||
const fetches: Record<string, Pick<ort.Tensor, 'dims' | 'type'>> = {};
|
||||
|
||||
// Creates an ort.Tensor from the plain number array of a test case, converting the data into the storage that the
// tensor type requires: BigInt arrays for (u)int64, raw 16-bit words for float16 and packed nibbles for (u)int4.
const createTensor = (type: ort.Tensor.Type, data: number[], dims: readonly number[]): ort.Tensor => {
  let buffer: number[] | BigUint64Array | BigInt64Array | Uint16Array | Uint8Array = data;
  if (type === 'uint64') {
    buffer = BigUint64Array.from(data.map(BigInt));
  } else if (type === 'int64') {
    buffer = BigInt64Array.from(data.map(BigInt));
  } else if (type === 'float16') {
    const dataArr = Float16ArrayPolyfill.from(data);
    buffer = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
  } else if (type === 'uint4' || type === 'int4') {
    buffer = new Uint8Array(calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!);
    // encode (u)int4 data into Uint8Array: two elements per byte, even indices in the low nibble and odd indices in
    // the high nibble
    for (let j = 0; j < data.length; j++) {
      /* eslint-disable no-bitwise */
      const byteIndex = j >> 1;
      const bitOffset = (j & 1) << 2;
      // Mask to 4 bits before shifting. Without the mask, a negative int4 value at an even index (e.g. -8, whose
      // low byte is 0xF8) would also set the high nibble and corrupt the element that follows it.
      buffer[byteIndex] |= (data[j] & 0xf) << bitOffset;
      /* eslint-enable no-bitwise */
    }
  }
  return new ort.Tensor(type, buffer, dims);
};
|
||||
|
||||
testCase.inputs.forEach((input, i) => {
|
||||
if (input.data) {
|
||||
let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = input.data;
|
||||
if (input.type === 'uint64') {
|
||||
data = BigUint64Array.from(input.data.map(BigInt));
|
||||
} else if (input.type === 'int64') {
|
||||
data = BigInt64Array.from(input.data.map(BigInt));
|
||||
} else if (input.type === 'float16') {
|
||||
const dataArr = Float16ArrayPolyfill.from(input.data);
|
||||
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
|
||||
}
|
||||
feeds[`input_${i}`] = new ort.Tensor(input.type, data, input.dims);
|
||||
feeds[`input_${i}`] = createTensor(input.type, input.data, input.dims);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -1052,16 +1054,7 @@ async function runProtoOpTestcase(
|
|||
const expectedOutputNames: string[] = [];
|
||||
testCase.outputs.forEach((output, i) => {
|
||||
if (output.data) {
|
||||
let data: number[] | BigUint64Array | BigInt64Array | Uint16Array = output.data;
|
||||
if (output.type === 'uint64') {
|
||||
data = BigUint64Array.from(output.data.map(BigInt));
|
||||
} else if (output.type === 'int64') {
|
||||
data = BigInt64Array.from(output.data.map(BigInt));
|
||||
} else if (output.type === 'float16') {
|
||||
const dataArr = Float16ArrayPolyfill.from(output.data);
|
||||
data = new Uint16Array(dataArr.buffer, dataArr.byteOffset, dataArr.byteLength / 2);
|
||||
}
|
||||
outputs.push(new ort.Tensor(output.type, data, output.dims));
|
||||
outputs.push(createTensor(output.type, output.data, output.dims));
|
||||
expectedOutputNames.push(`output_${i}`);
|
||||
fetches[`output_${i}`] = { dims: output.dims, type: output.type };
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue