[js] changes to allow Float16Array if any polyfill is available (#19305)

### Description

This change adds only necessary code to enable ort-web works with any
Float16Array polyfill. Unlike #19302, in this PR, ort-web does not
include any specific polyfill; instead, it's user's choice for how to
use a polyfill.

ORT-web uses Float16Array if it's available; otherwise, fallback to use
Uint16Array.

```js
// case 1: user does not use polyfill:
import * as ort from 'onnxruntime-web';

const myF16Data = new Uint16Array(...);  // need to use Uint16Array
const myF16tensor = new ort.Tensor('float16', myF16Data, dims);
```

```js
// case 2: user use polyfill:
import * as ort from 'onnxruntime-web';
import {
  Float16Array, isFloat16Array, isTypedArray,
  getFloat16, setFloat16,
  f16round,
} from "@petamoriken/float16";
globalThis.Float16Array = Float16Array;  // ort-web will pick the global Float16Array

const myF16Data = new Float16Array(...);  // Use the polyfilled Float16Array type
const myF16tensor = new ort.Tensor('float16', myF16Data, dims);
```
This commit is contained in:
Yulong Wang 2024-02-21 00:31:06 -08:00 committed by GitHub
parent 8092a89688
commit 58f4921686
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 37 additions and 16 deletions

View file

@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTy
['uint8', Uint8Array],
['int8', Int8Array],
['uint16', Uint16Array],
['float16', Uint16Array],
['int16', Int16Array],
['int32', Int32Array],
['bool', Uint8Array],
@ -34,16 +33,22 @@ export const NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP = new Map<SupportedTypedArray
[Uint32Array, 'uint32'],
]);
// the following code allows delaying execution of BigInt checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt polyfill
// if available.
let isBigIntChecked = false;
export const checkBigInt = () => {
if (!isBigIntChecked) {
isBigIntChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
const isBigUint64ArrayAvailable =
typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
}
// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
// polyfill if available.
let isTypedArrayChecked = false;
export const checkTypedArray = () => {
if (!isTypedArrayChecked) {
isTypedArrayChecked = true;
const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;
if (isBigInt64ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
@ -53,5 +58,12 @@ export const checkBigInt = () => {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
}
if (isFloat16ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
} else {
// if Float16Array is not available, use 'Uint16Array' to store the data.
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
}
}
};

View file

@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';
@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
TextureConstructorParameters|GpuBufferConstructorParameters,
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
// perform one-time check for BigInt support
checkBigInt();
// perform one-time check for BigInt/Float16Array support
checkTypedArray();
let type: TensorType;
let dims: readonly number[];
@ -142,7 +142,9 @@ export class Tensor implements TensorInterface {
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
}
if (Array.isArray(arg1)) {
if (arg0 === 'float16') {
if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
// When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array.
//
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.

View file

@ -3,6 +3,12 @@
import {Tensor} from 'onnxruntime-common';
// a dummy type declaration for Float16Array in case any polyfill is available.
declare global {
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
const Float16Array: any;
}
// This file includes common definitions. They do NOT have dependency on the WebAssembly instance.
/**
@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr
Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => {
switch (type) {
case 'float16':
return Uint16Array;
// allow Float16Array polyfill.
return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array;
case 'float32':
return Float32Array;
case 'uint8':