onnxruntime/js/web/lib/wasm/wasm-factory.ts
Yulong Wang b03c9496aa
[js/web] allow load WebAssembly binary from buffer (#21534)
### Description

This PR adds a new option `ort.env.wasm.wasmBinary`, which allows user
to set to a buffer containing preload .wasm file content.

This PR should resolve the problem from latest discussion in #20876.
2024-07-29 13:39:38 -07:00

201 lines
6.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Env} from 'onnxruntime-common';
import type {OrtWasmModule} from './wasm-types';
import {importWasmModule} from './wasm-utils-import';
let wasm: OrtWasmModule|undefined;
let initialized = false;
let initializing = false;
let aborted = false;
const isMultiThreadSupported = (): boolean => {
// If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
if (typeof SharedArrayBuffer === 'undefined') {
return false;
}
try {
// Test for transferability of SABs (for browsers. needed for Firefox)
// https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
if (typeof MessageChannel !== 'undefined') {
new MessageChannel().port1.postMessage(new SharedArrayBuffer(1));
}
// Test for WebAssembly threads capability (for both browsers and Node.js)
// This typed array is a WebAssembly program containing threaded instructions.
return WebAssembly.validate(new Uint8Array([
0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 5,
4, 1, 3, 1, 1, 10, 11, 1, 9, 0, 65, 0, 254, 16, 2, 0, 26, 11
]));
} catch (e) {
return false;
}
};
const isSimdSupported = (): boolean => {
try {
// Test for WebAssembly SIMD capability (for both browsers and Node.js)
// This typed array is a WebAssembly program containing SIMD instructions.
// The binary data is generated from the following code by wat2wasm:
//
// (module
// (type $t0 (func))
// (func $f0 (type $t0)
// (drop
// (i32x4.dot_i16x8_s
// (i8x16.splat
// (i32.const 0))
// (v128.const i32x4 0x00000000 0x00000000 0x00000000 0x00000000)))))
return WebAssembly.validate(new Uint8Array([
0, 97, 115, 109, 1, 0, 0, 0, 1, 4, 1, 96, 0, 0, 3, 2, 1, 0, 10, 30, 1, 28, 0, 65, 0,
253, 15, 253, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 253, 186, 1, 26, 11
]));
} catch (e) {
return false;
}
};
export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise<void> => {
if (initialized) {
return Promise.resolve();
}
if (initializing) {
throw new Error('multiple calls to \'initializeWebAssembly()\' detected.');
}
if (aborted) {
throw new Error('previous call to \'initializeWebAssembly()\' failed.');
}
initializing = true;
// wasm flags are already initialized
const timeout = flags.initTimeout!;
let numThreads = flags.numThreads!;
// ensure SIMD is supported
if (!isSimdSupported()) {
throw new Error('WebAssembly SIMD is not supported in the current environment.');
}
// check if multi-threading is supported
const multiThreadSupported = isMultiThreadSupported();
if (numThreads > 1 && !multiThreadSupported) {
if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
// eslint-disable-next-line no-console
console.warn(
'env.wasm.numThreads is set to ' + numThreads +
', but this will not work unless you enable crossOriginIsolated mode. ' +
'See https://web.dev/cross-origin-isolation-guide/ for more info.');
}
// eslint-disable-next-line no-console
console.warn(
'WebAssembly multi-threading is not supported in the current environment. ' +
'Falling back to single-threading.');
// set flags.numThreads to 1 so that OrtInit() will not create a global thread pool.
flags.numThreads = numThreads = 1;
}
const wasmPaths = flags.wasmPaths;
const wasmPrefixOverride = typeof wasmPaths === 'string' ? wasmPaths : undefined;
const mjsPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.mjs;
const mjsPathOverride = (mjsPathOverrideFlag as URL)?.href ?? mjsPathOverrideFlag;
const wasmPathOverrideFlag = (wasmPaths as Env.WasmFilePaths)?.wasm;
const wasmPathOverride = (wasmPathOverrideFlag as URL)?.href ?? wasmPathOverrideFlag;
const wasmBinaryOverride = flags.wasmBinary;
const [objectUrl, ortWasmFactory] = (await importWasmModule(mjsPathOverride, wasmPrefixOverride, numThreads > 1));
let isTimeout = false;
const tasks: Array<Promise<void>> = [];
// promise for timeout
if (timeout > 0) {
tasks.push(new Promise((resolve) => {
setTimeout(() => {
isTimeout = true;
resolve();
}, timeout);
}));
}
// promise for module initialization
tasks.push(new Promise((resolve, reject) => {
const config: Partial<OrtWasmModule> = {
/**
* The number of threads. WebAssembly will create (Module.numThreads - 1) workers. If it is 1, no worker will be
* created.
*/
numThreads,
};
if (wasmBinaryOverride) {
/**
* Set a custom buffer which contains the WebAssembly binary. This will skip the wasm file fetching.
*/
config.wasmBinary = wasmBinaryOverride;
} else if (wasmPathOverride || wasmPrefixOverride) {
/**
* A callback function to locate the WebAssembly file. The function should return the full path of the file.
*
* Since Emscripten 3.1.58, this function is only called for the .wasm file.
*/
config.locateFile = (fileName, scriptDirectory) =>
wasmPathOverride ?? (wasmPrefixOverride ?? scriptDirectory) + fileName;
}
ortWasmFactory(config).then(
// wasm module initialized successfully
module => {
initializing = false;
initialized = true;
wasm = module;
resolve();
if (objectUrl) {
URL.revokeObjectURL(objectUrl);
}
},
// wasm module failed to initialize
(what) => {
initializing = false;
aborted = true;
reject(what);
});
}));
await Promise.race(tasks);
if (isTimeout) {
throw new Error(`WebAssembly backend initializing failed due to timeout: ${timeout}ms`);
}
};
export const getInstance = (): OrtWasmModule => {
if (initialized && wasm) {
return wasm;
}
throw new Error('WebAssembly is not initialized yet.');
};
export const dispose = (): void => {
if (initialized && !initializing && !aborted) {
// TODO: currently "PThread.terminateAllThreads()" is not exposed in the wasm module.
// And this function is not yet called by any code.
// If it is needed in the future, we should expose it in the wasm module and uncomment the following line.
// wasm?.PThread?.terminateAllThreads();
wasm = undefined;
initializing = false;
initialized = false;
aborted = true;
}
};