mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
### Description This PR is a preview of cherry-picks for ort-web to `rel-1.17.3` based on `rel-1.17.2`. <details> <summary>Changes of ort-web to cherry-pick</summary> The following commits are from main branch. `o` stands for pick, and `x` stands for skip. ``` o2e0a388c36[js/webgpu] Add HardSigmoid support (#19215) od226e40856[js/webgpu] set query type in onRunStart (#19202) o61610ff986[js/webgpu] Add FusedConv clip test case (#18900) oa33b5bd1fa[JS/WebGPU] Added Uniforms to SkipLayerNorm. (#18788) o591f90c0b9[js/webgpu] Fix issue of timestamp query (#19258) o7252c6e747[WebNN EP] Support WebNN async API with Asyncify (#19145) o5b06505073[js/webgpu] Fix Tanh explosion (#19201) o656ca66186[js/webgpu] Support uniforms for conv, conv transpose, conv grouped (#18753) oa3f0e2422b[js/webgpu] Support f16 uniform (#19098) o9e69606360fix f16 for attention, enable slice and flatten for more types (#19262) o624b4e2063[js/webgpu] Remove enableShapesUniforms (#19279) o90883a366a[js/webgpu] Add hardSigmoid activation for fusedConv (#19233) o85cef0af8c[js/webgpu] Support capture and replay for jsep (#18989) od73131cf0f[js/webgpu] Use DataType as uniform cpu type (#19281) odd1f6ccc45[js/webgpu] resolve codescan alert (#19343) o3a2ab1963a[js/webgpu] Refactor createTensorShapeVariables (#18883) oefc17e79de[js/webgpu] Fix the undefined push error (#19366) x50806a7dd5[js/web] support external data in npm test (#19377) occbe264a39[js/webgpu] Add LeakyRelu activation for fusedConv (#19369) o5ff27ef02a[js/webgpu] support customop FastGelu (#19392) x03be65e064[js/web] fix types exports in package.json (#19458) o06269a3952[js/webgpu] allow uint8 tensors for webgpu (#19545) odfeda9019c[JS/WebGPU] Add MatMulNBits (#19446) o1b48054e1b[js/webgpu] Create Split indices helpers by rank, not by shape (#19554) o3fe2c137ee[js] small fix to workaround formatter (#19400) x70567a4b3a[js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidator (#19358) o6e04e36e3f[js/common] upgrade tsc in 
common from 4.9.5 to 5.2.2 (#19317) o58f4921686[js] changes to allow Float16Array if any polyfill is available (#19305) o57d6819212[js/web] Fix fused-conv is not included in npm test (#19581) oebd220b073Misspelling in README.md (#19433) o38c3432393Bump ip from 1.1.8 to 1.1.9 in /js/react_native (#19582) ofe82fccf1a[js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596) o76a2a487a1Bump ip from 1.1.8 to 1.1.9 in /js/react_native/e2e (#19583) o29b1106033[node] Switch to setImmediate to avoid starving the Node.js event loop (#19610) oae3d73c981[JS/WebGPU] Fix Split and Where to handle corner cases. (#19613) oaec2389ad0[js/webgpu] allows a ProgramInfo's RunData to use zero sized output (#19614) obb43a0f133[js/webgpu] minor fixes to make tinyllama work (#19564) o0edb035808[js/web] fix suite test list for zero sized tensor (#19638) o3cb81cdde2[js/common] move 'env.wasm.trace' to 'env.trace' (#19617) oe30618d055[js/webgpu] use Headless for webgpu test by default (#19702) of06164ef8b[js/web] transfer input buffer back to caller thread (#19677) xa788514027[js/web] dump debug logs for karma for diagnose purpose (#19785) o24b72d2613[JS/WebGPU] Preserve zero size input tensor dims. 
(#19737) o4538d31a8b[js/webgpu] expose a few properties in WebGPU API (#19857) o53de2d8cb0[js/webgpu] Enable GroupedConvVectorize path (#19791) oed250b88c3[JS/WebGPU] Optimize MatMulNBits (#19852) xe771a763c3[js/test] align web test runner flags with ort.env (#19790) o79e50aeef3[js/web] rewrite backend resolve to allow multiple EPs (#19735) oacb0df2280Fix #19931 broken Get Started link of "ONNX Runtime JavaScript API" page (#19932) ob29849a287[js/common] fix typedoc warnings (#19933) oafdab62f53Bump follow-redirects from 1.15.4 to 1.15.6 in /js/web (#19949) o28ad6c3955Bump follow-redirects from 1.15.4 to 1.15.6 in /js/node (#19951) o7e0d424934accumulate in fp32 for Reduce* (#19868) o4c6a6a37f7[js/webgpu] Fix NAN caused by un-initialized buffer in instance-norm (#19387) o01c7aaf6aa[js/webgpu] allow setting env.webgpu.adapter (#19940) oc45cff60cf[js/webgpu] fix maxpool / fp16 (#19981) ``` </details> <details> <summary>Cherry-pick commandlines</summary> ```sh git cherry-pick2e0a388c36git cherry-pickd226e40856git cherry-pick61610ff986git cherry-picka33b5bd1fagit cherry-pick591f90c0b9git cherry-pick7252c6e747git cherry-pick5b06505073git cherry-pick656ca66186git cherry-picka3f0e2422bgit cherry-pick9e69606360git cherry-pick624b4e2063git cherry-pick90883a366agit cherry-pick85cef0af8c#<<<<< Note: conflicts git cherry-pickd73131cf0fgit cherry-pickdd1f6ccc45git cherry-pick3a2ab1963agit cherry-pickefc17e79degit cherry-pickccbe264a39git cherry-pick5ff27ef02agit cherry-pick06269a3952git cherry-pickdfeda9019cgit cherry-pick1b48054e1bgit cherry-pick3fe2c137eegit cherry-pick6e04e36e3fgit cherry-pick58f4921686git cherry-pick57d6819212git cherry-pickebd220b073git cherry-pick38c3432393git cherry-pickfe82fccf1agit cherry-pick76a2a487a1git cherry-pick29b1106033git cherry-pickae3d73c981git cherry-pickaec2389ad0git cherry-pickbb43a0f133git cherry-pick0edb035808git cherry-pick3cb81cdde2git cherry-picke30618d055git cherry-pickf06164ef8bgit cherry-pick24b72d2613git cherry-pick4538d31a8bgit 
cherry-pick53de2d8cb0git cherry-picked250b88c3git cherry-pick79e50aeef3git cherry-pickacb0df2280git cherry-pickb29849a287git cherry-pickafdab62f53git cherry-pick28ad6c3955git cherry-pick7e0d424934git cherry-pick4c6a6a37f7git cherry-pick01c7aaf6aagit cherry-pickc45cff60cf``` </details> <details> <summary>Cherry-pick conflicts</summary> -85cef0af8c#18989 this change is for enabling graph capture feature for JSEP, and it is done after ROCM EP enabled graph capture feature. However, the ROCM EP graph capture feature is not cherry-picked in rel-1.17.2. </details> --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Jiajia Qin <jiajia.qin@intel.com> Co-authored-by: Xu Xing <xing.xu@intel.com> Co-authored-by: satyajandhyala <satya.k.jandhyala@gmail.com> Co-authored-by: Yang Gu <yang.gu@intel.com> Co-authored-by: Wanming Lin <wanming.lin@intel.com> Co-authored-by: Jiajie Hu <jiajie.hu@intel.com> Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com> Co-authored-by: Matttttt <18152455+martholomew@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Segev Finer <segev208@gmail.com> Co-authored-by: Belem Zhang <belem.zhang@intel.com>
716 lines
28 KiB
TypeScript
716 lines
28 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
|
|
|
|
import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
|
|
import {setRunOptions} from './run-options';
|
|
import {setSessionOptions} from './session-options';
|
|
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
|
|
import {getInstance} from './wasm-factory';
|
|
import {allocWasmString, checkLastError} from './wasm-utils';
|
|
import {loadFile} from './wasm-utils-load-file';
|
|
|
|
// #region Initializations
|
|
|
|
/**
|
|
 * There are 4 different "initialization" steps for ORT. They happen in different places and at different times.
|
|
*
|
|
* 1. JavaScript initialization for onnxruntime-common and onnxruntime-web.
|
|
* This is the first initialization step. In this step, onnxruntime-web calls onnxruntime-common's registerBackend()
|
|
* function multiple times to register all the available backends. The backend registration is very fast. It only
|
|
* registers the backend name with the uninitialized backend object. No heavy initialization is done in this step.
|
|
* Refer to web/lib/index.ts for the backend registration.
|
|
*
|
|
* 2. WebAssembly artifact initialization.
|
|
 * This happens when any registered wasm backend is used for the first time (i.e. `ort.InferenceSession.create()` or
|
|
 * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the following:
|
|
* - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
|
|
* - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
|
|
* JavaScript code to initialize the WebAssembly runtime.
|
|
* - if proxy is enabled, this step happens in the proxy worker using message 'init-wasm'.
|
|
* - downloading the 'ort-wasm{...}.wasm' file is done in this step.
|
|
 * - if multi-thread is enabled, one or more web workers will be created to initialize the PThread threadpool.
|
|
*
|
|
* 3. ORT environment initialization.
|
|
* This happens after step 2. In this step, onnxruntime-web performs ONNX Runtime environment initialization.
|
|
* Function `_OrtInit()` is called in this step.
|
|
* - if proxy is enabled, this step happens in the proxy worker using message 'init-ort'.
|
|
* - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
|
|
*
|
|
* 4. Session initialization.
|
|
* This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
|
|
 * steps (which are only called once), this step is done for each session. In this step, onnxruntime-web does the
|
|
 * following:
|
|
* If the parameter is a URL:
|
|
* - download the model data from the URL.
|
|
* - copy the model data to the WASM heap. (proxy: 'copy-from')
|
|
* - dereference the model buffer. This step allows the original ArrayBuffer to be garbage collected.
|
|
* - call `_OrtCreateSession()` to create the session. (proxy: 'create')
|
|
*
|
|
* If the parameter is a Uint8Array object:
|
|
* - copy the model data to the WASM heap. (proxy: 'copy-from')
|
|
* - call `_OrtCreateSession()` to create the session. (proxy: 'create')
|
|
*
|
|
*
|
|
*/
|
|
|
|
/**
 * Initialize the ORT environment by calling `_OrtInit()` in the WebAssembly module.
 *
 * @param numThreads forwarded to `_OrtInit()` as the thread count (SetGlobalIntraOpNumThreads(numThreads)).
 * @param loggingLevel forwarded to `_OrtInit()` as the logging level (CreateEnv(static_cast<OrtLoggingLevel>(...))).
 *
 * A non-zero status from `_OrtInit()` is reported through `checkLastError()`.
 */
const initOrt = (numThreads: number, loggingLevel: number): void => {
  const status = getInstance()._OrtInit(numThreads, loggingLevel);
  if (status === 0) {
    return;
  }
  checkLastError('Can\'t initialize onnxruntime.');
};
|
|
|
|
/**
 * Initialize the runtime environment.
 *
 * @param env the environment config object. `env.wasm.numThreads` and `env.logLevel` (converted to its enum value)
 *     are forwarded to the ORT environment initialization.
 */
export const initRuntime = async(env: Env): Promise<void> => {
  const {numThreads} = env.wasm;
  const loggingLevel = logLevelStringToEnum(env.logLevel);
  initOrt(numThreads!, loggingLevel);
};
|
|
|
|
/**
 * Perform execution-provider specific initialization.
 *
 * Only 'webgpu' and 'webnn' require work here. Both are initialized through the JSEP `init()` function and are only
 * available in builds where WebGPU is not disabled. Any other EP name is a no-op.
 *
 * @param env the environment config object. For 'webgpu', `env.webgpu.adapter`, `env.webgpu.powerPreference`,
 *     `env.webgpu.forceFallbackAdapter` and `env.wasm.simd` are consulted.
 * @param epName the name of the execution provider to initialize.
 */
export const initEp = async(env: Env, epName: string): Promise<void> => {
  // keep the require() inside this compile-time guard so builds without WebGPU do not pull in the JSEP module.
  if (!BUILD_DEFS.DISABLE_WEBGPU) {
    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
    const initJsep = require('./jsep/init').init;

    switch (epName) {
      case 'webgpu': {
        const hasWebGpu = typeof navigator !== 'undefined' && !!navigator.gpu;
        if (!hasWebGpu) {
          throw new Error('WebGPU is not supported in current environment');
        }

        let adapter = env.webgpu.adapter as GPUAdapter | null;
        if (adapter) {
          // a user supplied adapter: make sure it at least looks like a GPUAdapter.
          const looksLikeAdapter = typeof adapter.limits === 'object' && typeof adapter.features === 'object' &&
              typeof adapter.requestDevice === 'function';
          if (!looksLikeAdapter) {
            throw new Error('Invalid GPU adapter set in `env.webgpu.adapter`. It must be a GPUAdapter object.');
          }
        } else {
          // no adapter supplied: validate the request options, then ask the browser for one.
          const powerPreference = env.webgpu.powerPreference;
          const isKnownPowerPreference =
              powerPreference === undefined || powerPreference === 'low-power' || powerPreference === 'high-performance';
          if (!isKnownPowerPreference) {
            throw new Error(`Invalid powerPreference setting: "${powerPreference}"`);
          }

          const forceFallbackAdapter = env.webgpu.forceFallbackAdapter;
          const isValidFallbackFlag = forceFallbackAdapter === undefined || typeof forceFallbackAdapter === 'boolean';
          if (!isValidFallbackFlag) {
            throw new Error(`Invalid forceFallbackAdapter setting: "${forceFallbackAdapter}"`);
          }

          adapter = await navigator.gpu.requestAdapter({powerPreference, forceFallbackAdapter});
          if (!adapter) {
            throw new Error(
                'Failed to get GPU adapter. ' +
                'You may need to enable flag "--enable-unsafe-webgpu" if you are using Chrome.');
          }
        }

        if (!env.wasm.simd) {
          throw new Error(
              'Not supported for WebGPU=ON and SIMD=OFF. Please set `env.wasm.simd` to true when using `webgpu` EP');
        }

        await initJsep('webgpu', getInstance(), env, adapter);
        break;
      }

      case 'webnn': {
        const hasWebNn = typeof navigator !== 'undefined' && !!(navigator as unknown as {ml: unknown}).ml;
        if (!hasWebNn) {
          throw new Error('WebNN is not supported in current environment');
        }

        await initJsep('webnn', getInstance(), env);
        break;
      }

      default:
        break;
    }
  }
};
|
|
|
|
// #endregion Initializations
|
|
|
|
/**
 * valid data locations for input/output tensors.
 */
type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer';

type IOBindingState = {
  /**
   * the handle of IO binding (as returned by `_OrtCreateBinding()`).
   */
  readonly handle: number;

  /**
   * the preferred location for each output tensor, indexed by the model's output index.
   *
   * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'.
   */
  readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[];

  /**
   * enum value of the preferred location for each output tensor (same indexing as `outputPreferredLocations`).
   */
  readonly outputPreferredLocationsEncoded: readonly number[];
};

/**
 * tuple elements are:
 * - inferenceSessionId: the session handle returned by `_OrtCreateSession()`; also the key in `activeSessions`.
 * - inputNamesUTF8Encoded: pointers to the UTF-8 encoded input names, indexed by model input index.
 * - outputNamesUTF8Encoded: pointers to the UTF-8 encoded output names, indexed by model output index.
 * - bindingState: the IO binding state, or null when IO binding is not used for this session.
 * - enableGraphCapture: whether the session was created with `enableGraphCapture` set.
 * - inputOutputBound: whether inputs/outputs are currently bound to the IO binding (so binding can be skipped).
 */
type SessionMetadata = [
  inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[],
  bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean
];

/**
 * sessions created by `createSession()` and not yet released, keyed by session handle.
 */
const activeSessions = new Map<number, SessionMetadata>();
|
|
|
|
/**
 * Query the number of inputs and outputs of a session.
 *
 * @param sessionHandle the handle representing the session. should be non-zero.
 * @returns a tuple of [input count, output count].
 */
const getSessionInputOutputCount = (sessionHandle: number): [number, number] => {
  const wasm = getInstance();
  const savedStack = wasm.stackSave();
  try {
    // two consecutive 32-bit slots on the WASM stack: the first receives the input count, the second the output count.
    const countsPtr = wasm.stackAlloc(8);
    const inputCountPtr = countsPtr;
    const outputCountPtr = countsPtr + 4;
    if (wasm._OrtGetInputOutputCount(sessionHandle, inputCountPtr, outputCountPtr) !== 0) {
      checkLastError('Can\'t get session input/output count.');
    }
    const firstSlot = countsPtr / 4;
    return [wasm.HEAP32[firstSlot], wasm.HEAP32[firstSlot + 1]];
  } finally {
    wasm.stackRestore(savedStack);
  }
};
|
|
|
|
/**
 * Copy an external buffer into newly allocated WASM heap memory.
 *
 * @param model - the external buffer containing the model data. Must not be the same buffer as the WASM heap.
 * @returns a 2-elements tuple [pointer, size] describing the allocated copy. The caller owns the allocation.
 * @throws Error if the heap allocation fails.
 */
export const copyFromExternalBuffer = (model: Uint8Array): [number, number] => {
  const wasm = getInstance();
  const size = model.byteLength;
  const pointer = wasm._malloc(size);
  if (pointer === 0) {
    throw new Error(`Can't create a session. failed to allocate a buffer of size ${model.byteLength}.`);
  }
  wasm.HEAPU8.set(model, pointer);
  return [pointer, size];
};
|
|
|
|
/**
 * create an inference session from a model data buffer.
 *
 * The model bytes come from one of three sources:
 * - a 2-elements tuple [pointer, size] (the same shape `copyFromExternalBuffer()` returns): used as-is;
 * - a Uint8Array whose buffer is the WASM heap itself: used in place, without copying;
 * - any other Uint8Array: copied into the WASM heap with `copyFromExternalBuffer()`.
 * In every case `modelDataOffset` is passed to `_free()` in the `finally` block once session creation is done.
 * NOTE(review): for the heap-aliased Uint8Array case this frees memory the caller handed in; presumably the caller
 * transfers ownership — confirm.
 *
 * On success the session is registered in `activeSessions` (keyed by its handle). On failure, the input/output name
 * buffers, the IO binding and the session acquired so far are released before the error is rethrown.
 *
 * @param modelData - either a Uint8Array object representing the model data, or a 2-elements tuple containing the
 *     pointer and size of the model data buffer.
 * @param options an optional session options object. `externalData`, `enableGraphCapture` and
 *     `preferredOutputLocation` are read here; the options are also passed to `setSessionOptions()`.
 * @returns a 3-elements tuple containing [session handle, input names, output names]
 */
export const createSession = async(
    modelData: Uint8Array|SerializableInternalBuffer,
    options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
  let modelDataOffset: number, modelDataLength: number;
  const wasm = getInstance();

  if (Array.isArray(modelData)) {
    // if model data is an array, it must be a 2-elements tuple containing the pointer and size of the model data
    [modelDataOffset, modelDataLength] = modelData;
  } else if (modelData.buffer === wasm.HEAPU8.buffer) {
    // if model data uses the same buffer as the WASM heap, we don't need to copy it.
    [modelDataOffset, modelDataLength] = [modelData.byteOffset, modelData.byteLength];
  } else {
    // otherwise, copy the model data to the WASM heap.
    [modelDataOffset, modelDataLength] = copyFromExternalBuffer(modelData);
  }

  // resources acquired below. 0 / empty means "not acquired yet"; the catch and finally blocks rely on this to decide
  // what must be released.
  let sessionHandle = 0;
  let sessionOptionsHandle = 0;
  let ioBindingHandle = 0;
  let allocs: number[] = [];
  const inputNamesUTF8Encoded = [];
  const outputNamesUTF8Encoded = [];

  try {
    // `allocs` receives heap pointers allocated by setSessionOptions(); they are freed in the finally block.
    [sessionOptionsHandle, allocs] = setSessionOptions(options);

    if (options?.externalData && wasm.mountExternalData) {
      // start loading every external data file concurrently; each one is mounted under its path once loaded.
      const loadingPromises = [];
      for (const file of options.externalData) {
        const path = typeof file === 'string' ? file : file.path;
        loadingPromises.push(loadFile(typeof file === 'string' ? file : file.data).then(data => {
          wasm.mountExternalData!(path, data);
        }));
      }

      // wait for all external data files to be loaded
      await Promise.all(loadingPromises);
    }

    sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
    if (sessionHandle === 0) {
      checkLastError('Can\'t create a session.');
    }

    const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);

    const enableGraphCapture = !!options?.enableGraphCapture;

    const inputNames = [];
    const outputNames = [];
    const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
    for (let i = 0; i < inputCount; i++) {
      const name = wasm._OrtGetInputName(sessionHandle, i);
      if (name === 0) {
        checkLastError('Can\'t get an input name.');
      }
      // keep the encoded pointer (freed when the session is released) and a decoded copy (returned to the caller).
      inputNamesUTF8Encoded.push(name);
      inputNames.push(wasm.UTF8ToString(name));
    }
    for (let i = 0; i < outputCount; i++) {
      const name = wasm._OrtGetOutputName(sessionHandle, i);
      if (name === 0) {
        checkLastError('Can\'t get an output name.');
      }
      outputNamesUTF8Encoded.push(name);
      const nameString = wasm.UTF8ToString(name);
      outputNames.push(nameString);

      if (!BUILD_DEFS.DISABLE_WEBGPU) {
        // with graph capture and no explicit preference, every output defaults to 'gpu-buffer'.
        if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
          outputPreferredLocations.push('gpu-buffer');
          continue;
        }
        // preferredOutputLocation is either a single location for all outputs, or a map from output name to
        // location; outputs missing from the map default to 'cpu'.
        const location = typeof options?.preferredOutputLocation === 'string' ?
            options.preferredOutputLocation :
            options?.preferredOutputLocation?.[nameString] ?? 'cpu';
        if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
          throw new Error(`Not supported preferred output location: ${location}.`);
        }
        if (enableGraphCapture && location !== 'gpu-buffer') {
          throw new Error(`Not supported preferred output location: ${
              location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`);
        }
        outputPreferredLocations.push(location);
      }
    }

    // use IO binding only when at least one output is preferred to be on GPU.
    let bindingState: IOBindingState|null = null;
    if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) {
      ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
      if (ioBindingHandle === 0) {
        checkLastError('Can\'t create IO binding.');
      }

      bindingState = {
        handle: ioBindingHandle,
        outputPreferredLocations,
        outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)),
      };
    }

    // register the session; the trailing `false` marks inputs/outputs as not yet bound.
    activeSessions.set(
        sessionHandle,
        [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]);
    return [sessionHandle, inputNames, outputNames];
  } catch (e) {
    // roll back everything acquired so far, then propagate the original error.
    inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
    outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

    if (ioBindingHandle !== 0) {
      wasm._OrtReleaseBinding(ioBindingHandle);
    }

    if (sessionHandle !== 0) {
      wasm._OrtReleaseSession(sessionHandle);
    }
    throw e;
  } finally {
    // released on both success and failure: the model data buffer, the session options and their allocations.
    wasm._free(modelDataOffset);
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach(alloc => wasm._free(alloc));

    // unmount external data if necessary
    wasm.unmountExternalData?.();
  }
};
|
|
|
|
/**
 * Release a session created by `createSession()` together with the resources registered for it: its IO binding (if
 * any), its encoded input/output name buffers and the native session itself.
 *
 * @param sessionId the session handle returned by `createSession()`.
 * @throws Error if `sessionId` does not refer to an active session.
 */
export const releaseSession = (sessionId: number): void => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (session === undefined) {
    throw new Error(`cannot release session. invalid session id: ${sessionId}`);
  }
  const [sessionHandle, inputNames, outputNames, bindingState, graphCaptureEnabled] = session;

  if (bindingState !== null) {
    const bindingHandle = bindingState.handle;
    // with graph capture, outputs stay bound across runs and are only cleared here.
    if (graphCaptureEnabled) {
      wasm._OrtClearBoundOutputs(bindingHandle);
    }
    wasm._OrtReleaseBinding(bindingHandle);
  }

  wasm.jsepOnReleaseSession?.(sessionId);

  for (const namePtr of [...inputNames, ...outputNames]) {
    wasm._OrtFree(namePtr);
  }
  wasm._OrtReleaseSession(sessionHandle);
  activeSessions.delete(sessionId);
};
|
|
|
|
/**
 * Create an ORT tensor for one input or output of a session and append its handle to `tensorHandles`.
 *
 * @param tensor tensor metadata [dataType, dims, data, location], or null to append a 0 handle (used for outputs
 *     that are not pre-allocated).
 * @param tensorHandles receives the created tensor handle (or 0 when `tensor` is null).
 * @param allocs receives every WASM heap pointer allocated here; the caller is responsible for freeing them.
 * @param sessionId the session the tensor belongs to.
 * @param index the input index, or (input count + output index) for outputs; used for GPU buffer registration and
 *     in error messages.
 * @param enableGraphCapture when true, only 'gpu-buffer' tensors are accepted.
 */
export const prepareInputOutputTensor =
    (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number,
     enableGraphCapture = false): void => {
      if (!tensor) {
        tensorHandles.push(0);
        return;
      }

      const wasm = getInstance();

      // allocate `size` bytes on the WASM heap and record the pointer in `allocs`. `_malloc()` returns 0 when the
      // heap cannot grow; writing through that pointer would silently corrupt memory, so fail loudly instead
      // (consistent with copyFromExternalBuffer()).
      const allocTracked = (size: number): number => {
        const pointer = wasm._malloc(size);
        if (pointer === 0 && size > 0) {
          throw new Error(
              `Can't allocate a buffer of size ${size} for input/output. session=${sessionId}, index=${index}.`);
        }
        allocs.push(pointer);
        return pointer;
      };

      const dataType = tensor[0];
      const dims = tensor[1];
      const location = tensor[3];

      let rawData: number;
      let dataByteLength: number;

      if (dataType === 'string' && location === 'gpu-buffer') {
        throw new Error('String tensor is not supported on GPU.');
      }

      if (enableGraphCapture && location !== 'gpu-buffer') {
        throw new Error(
            `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`);
      }

      if (location === 'gpu-buffer') {
        const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
        const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
        dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;

        const registerBuffer = wasm.jsepRegisterBuffer;
        if (!registerBuffer) {
          throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.');
        }
        rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength);
      } else {
        const data = tensor[2];

        if (Array.isArray(data)) {
          // string tensor: an array of 32-bit pointers, each to a string allocated in the WASM heap.
          dataByteLength = 4 * data.length;
          rawData = allocTracked(dataByteLength);
          let dataIndex = rawData / 4;
          for (let i = 0; i < data.length; i++) {
            if (typeof data[i] !== 'string') {
              throw new TypeError(`tensor data at index ${i} is not a string`);
            }
            wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
          }
        } else {
          dataByteLength = data.byteLength;
          rawData = allocTracked(dataByteLength);
          wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
        }
      }

      const stack = wasm.stackSave();
      const dimsOffset = wasm.stackAlloc(4 * dims.length);
      try {
        let dimIndex = dimsOffset / 4;
        dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
        const tensorHandle = wasm._OrtCreateTensor(
            tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
            dataLocationStringToEnum(location));
        if (tensorHandle === 0) {
          checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
        }
        tensorHandles.push(tensorHandle);
      } finally {
        wasm.stackRestore(stack);
      }
    };
|
|
|
|
/**
 * perform inference run.
 *
 * @param sessionId the session handle returned by `createSession()`.
 * @param inputIndices the model input index of each entry in `inputTensors`.
 * @param inputTensors the input tensors.
 * @param outputIndices the model output index of each entry in `outputTensors`.
 * @param outputTensors pre-allocated output tensors, or null for outputs that ORT should allocate.
 * @param options run options.
 * @returns metadata of the output tensors, in the same order as `outputIndices`.
 */
export const run = async(
    sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
    outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  const wasm = getInstance();
  const session = activeSessions.get(sessionId);
  if (!session) {
    throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
  }
  const sessionHandle = session[0];
  const inputNamesUTF8Encoded = session[1];
  const outputNamesUTF8Encoded = session[2];
  const ioBindingState = session[3];
  const enableGraphCapture = session[4];
  const inputOutputBound = session[5];

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();
  const inputValuesOffset = wasm.stackAlloc(inputCount * 4);
  const inputNamesOffset = wasm.stackAlloc(inputCount * 4);
  const outputValuesOffset = wasm.stackAlloc(outputCount * 4);
  const outputNamesOffset = wasm.stackAlloc(outputCount * 4);

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // create input tensors
    for (let i = 0; i < inputCount; i++) {
      prepareInputOutputTensor(
          inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture);
    }

    // create output tensors
    for (let i = 0; i < outputCount; i++) {
      prepareInputOutputTensor(
          outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i],
          enableGraphCapture);
    }

    // fill the name/value pointer arrays on the WASM stack consumed by _OrtRun().
    let inputValuesIndex = inputValuesOffset / 4;
    let inputNamesIndex = inputNamesOffset / 4;
    let outputValuesIndex = outputValuesOffset / 4;
    let outputNamesIndex = outputNamesOffset / 4;
    for (let i = 0; i < inputCount; i++) {
      wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i];
      wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]];
    }
    for (let i = 0; i < outputCount; i++) {
      wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i];
      wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
    }

    if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) {
      const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState;

      if (inputNamesUTF8Encoded.length !== inputCount) {
        throw new Error(`input count from feeds (${
            inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`);
      }

      // process inputs
      for (let i = 0; i < inputCount; i++) {
        const index = inputIndices[i];
        const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]);
        if (errorCode !== 0) {
          checkLastError(`Can't bind input[${i}] for session=${sessionId}.`);
        }
      }

      // process pre-allocated outputs
      for (let i = 0; i < outputCount; i++) {
        const index = outputIndices[i];
        const location = outputTensors[i]?.[3];  // undefined means output is not pre-allocated.

        if (location) {
          // output is pre-allocated. bind the tensor.
          const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0);
          if (errorCode !== 0) {
            checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`);
          }
        } else {
          // output is not pre-allocated. reset preferred location.
          const errorCode =
              wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]);
          if (errorCode !== 0) {
            // preferred locations are indexed by model output index, so report the one actually requested.
            checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[index]} for session=${sessionId}.`);
          }
        }
      }
      activeSessions.set(
          sessionId,
          [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]);
    }

    wasm.jsepOnRunStart?.(sessionHandle);
    let errorCode: number;
    if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
      errorCode = await wasm._OrtRunWithBinding(
          sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
    } else {
      errorCode = await wasm._OrtRun(
          sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
          outputValuesOffset, runOptionsHandle);
    }

    if (errorCode !== 0) {
      checkLastError('failed to call OrtRun().');
    }

    const output: TensorMetadata[] = [];

    for (let i = 0; i < outputCount; i++) {
      const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
      if (tensor === outputTensorHandles[i]) {
        // output tensor is pre-allocated. no need to copy data.
        output.push(outputTensors[i]!);
        continue;
      }

      const beforeGetTensorDataStack = wasm.stackSave();
      // stack allocate 4 pointer value
      const tensorDataOffset = wasm.stackAlloc(4 * 4);

      let keepOutputTensor = false;
      let type: Tensor.Type|undefined, dataOffset = 0;
      try {
        const errorCode = wasm._OrtGetTensorData(
            tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
        if (errorCode !== 0) {
          checkLastError(`Can't access output tensor data on index ${i}.`);
        }
        let tensorDataIndex = tensorDataOffset / 4;
        const dataType = wasm.HEAPU32[tensorDataIndex++];
        dataOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
        const dimsLength = wasm.HEAPU32[tensorDataIndex++];
        const dims = [];
        for (let d = 0; d < dimsLength; d++) {
          dims.push(wasm.HEAPU32[dimsOffset / 4 + d]);
        }
        wasm._OrtFree(dimsOffset);

        const size = dims.reduce((a, b) => a * b, 1);
        type = tensorDataTypeEnumToString(dataType);

        const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]];

        if (type === 'string') {
          if (preferredLocation === 'gpu-buffer') {
            throw new Error('String tensor is not supported on GPU.');
          }
          const stringData: string[] = [];
          let dataIndex = dataOffset / 4;
          for (let s = 0; s < size; s++) {
            const offset = wasm.HEAPU32[dataIndex++];
            const maxBytesToRead = s === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
            stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
          }
          output.push([type, dims, stringData, 'cpu']);
        } else {
          // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU
          // tensor for it. There is no mapping GPU buffer for an empty tensor.
          if (preferredLocation === 'gpu-buffer' && size > 0) {
            const getBuffer = wasm.jsepGetBuffer;
            if (!getBuffer) {
              throw new Error('preferredLocation "gpu-buffer" is not supported without using WebGPU.');
            }
            const gpuBuffer = getBuffer(dataOffset);
            const elementSize = getTensorElementSize(dataType);
            if (elementSize === undefined || !isGpuBufferSupportedType(type)) {
              throw new Error(`Unsupported data type: ${type}`);
            }

            // do not release the tensor right now. it will be released when user calls tensor.dispose().
            keepOutputTensor = true;

            output.push([
              type, dims, {
                gpuBuffer,
                download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type),
                dispose: () => {
                  wasm._OrtReleaseTensor(tensor);
                }
              },
              'gpu-buffer'
            ]);
          } else {
            const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
            const data = new typedArrayConstructor(size);
            new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
                .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
            output.push([type, dims, data, 'cpu']);
          }
        }
      } finally {
        wasm.stackRestore(beforeGetTensorDataStack);
        if (type === 'string' && dataOffset) {
          wasm._free(dataOffset);
        }
        if (!keepOutputTensor) {
          wasm._OrtReleaseTensor(tensor);
        }
      }
    }

    return output;
  } finally {
    // Without graph capture, IO bindings only live for a single run. Reset them here rather than only on success:
    // otherwise a failed run would leave `inputOutputBound` set to true, and the next run would skip re-binding and
    // silently execute with this run's stale inputs.
    if (ioBindingState && !enableGraphCapture) {
      wasm._OrtClearBoundOutputs(ioBindingState.handle);
      activeSessions.set(
          sessionId,
          [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]);
    }

    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach(p => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};
|
|
|
|
/**
 * End profiling for a session.
 *
 * @param sessionId the session handle returned by `createSession()`.
 * @throws Error if `sessionId` does not refer to an active session.
 */
export const endProfiling = (sessionId: number): void => {
  const wasm = getInstance();
  const sessionHandle = activeSessions.get(sessionId)?.[0];
  if (sessionHandle === undefined) {
    throw new Error('invalid session id');
  }

  // the returned profile file name is not used yet, but its buffer must still be freed.
  const profileFileNamePtr = wasm._OrtEndProfiling(sessionHandle);
  if (profileFileNamePtr === 0) {
    checkLastError('Can\'t get an profile file name.');
  }
  wasm._OrtFree(profileFileNamePtr);
};
|
|
|
|
/**
 * Collect the underlying buffers of the given tensors so they can be passed as transferables (e.g. to postMessage).
 * Plain-array data (string tensors) and data objects without a `buffer` property are skipped.
 *
 * @param tensors the serializable tensor metadata to scan.
 * @returns the buffers, in tensor order.
 */
export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] =>
    tensors.reduce<ArrayBufferLike[]>((transferables, [, , data]) => {
      if (!Array.isArray(data) && 'buffer' in data) {
        transferables.push(data.buffer);
      }
      return transferables;
    }, []);
|