onnxruntime/js/web/lib/wasm/proxy-wrapper.ts
Ye Wang 83dc22585c
Second round cherry-pick to rel-1.9.0 (#9062)
* Adding async fetching for webgl backend (#8951)

* Adding async fetching for webgl backend

* fix PR comments and CI failure.

* fixing a bug

* adding a flag

* Enable linking in the exception-throwing support library when building onnxruntime wasm. (#8973)

* Enable linking in the exception-throwing support library when building an onnxruntime WebAssembly binary that contains onnxruntime-extensions.

* Add a flag in build.py to enable linking the exception-throwing library.

* Update the onnxruntime-extensions documentation and bind the custom_ops build flag to use_extensions.

* Update doc.

* Update cgmanifest.json.

Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>

* Remove document text from error message in a couple of ops (#9003)

* do not add pkg wheel entry to the index html file if it already exists (#9004)

* do not add pkg wheel entry to the index html file if it already exists

* [js/web] fix ort web e2e test (#9025)

* Fix cmake POWER10 detection

Recent commit 60c98a8 changed the variable mlas_common_srcs, which affects
POWER10 detection.

* Fix Where op type reduction processing (#9033)

* Update type reduction script to track Where Op's second input type.

* Clean up op_kernel_type_control.h includes.

* Use more maintainable include.

* Fix ROCm wheels CI pipeline break by installing latest protobuf from source (#9047)

* install protobuf from source

* fix rm command in Dockerfile

* fix options on rm command

* fix cd into protobuf source directory

* try again

* remove strip step

* debug list the files

* ls on /usr

* more debug

* more debug

* adjust LD_LIBRARY_PATH

* try remove protobuf before ORT build

* [js/web] a bugfix and add tests for wasm proxy worker (#9048)

* [js/web] add tests for wasm proxy worker

* fix script src override

* Set onnxruntime_DISABLE_RTTI to default OFF (#9049)

Co-authored-by: Du Li <duli1@microsoft.com>
Co-authored-by: Zuwei Zhao <4123666+Zuwei-Zhao@users.noreply.github.com>
Co-authored-by: Zuwei Zhao <zuzhao@microsoft.com>
Co-authored-by: Hariharan Seshadri <shariharan91@gmail.com>
Co-authored-by: liqun Fu <liqfu@microsoft.com>
Co-authored-by: Yulong Wang <yulongw@microsoft.com>
Co-authored-by: Rajalakshmi Srinivasaraghavan <rajis@linux.ibm.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Suffian Khan <sukha@microsoft.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
2021-09-15 18:02:07 -07:00

186 lines
6 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {env, InferenceSession} from 'onnxruntime-common';
import {OrtWasmMessage, SerializableSessionMetadata, SerializableTensor} from './proxy-messages';
import * as core from './wasm-core-impl';
import {initializeWebAssembly} from './wasm-factory';
/**
 * Decides whether calls must be forwarded to the proxy worker.
 *
 * Both conditions are required: proxying is requested through a truthy `env.wasm.proxy`, and a global
 * `document` is defined.
 */
const isProxy = (): boolean => {
  const proxyRequested = Boolean(env.wasm.proxy);
  return proxyRequested ? typeof document !== 'undefined' : false;
};
// Worker hosting the WebAssembly runtime when proxying is enabled; (re)created by `initWasm()`.
let proxyWorker: Worker|undefined;

// State of the proxy worker's WebAssembly initialization.
let initializing = false;  // an initialization request is in flight
let initialized = false;   // initialization completed successfully
let aborted = false;       // a previous initialization attempt failed

// Settlement callbacks of a pending promise, stored as the pair [resolve, reject].
type PromiseCallbacks<T = void> = [(result: T) => void, (reason: unknown) => void];

// Single-slot callbacks for the initialization requests.
let initWasmCallbacks: PromiseCallbacks;
let initOrtCallbacks: PromiseCallbacks;

// FIFO queues of callbacks for requests still waiting for a worker reply; each reply settles the oldest entry
// of the matching queue.
const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata>> = [];
const releaseSessionCallbacks: PromiseCallbacks[] = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensor[]>> = [];
const endProfilingCallbacks: PromiseCallbacks[] = [];
/**
 * Asserts that the proxy worker can accept requests.
 *
 * @throws Error('worker not ready') when initialization is still in flight, has not completed, has failed, or
 *     when no worker exists.
 */
const ensureWorker = (): void => {
  const ready = !initializing && initialized && !aborted && !!proxyWorker;
  if (!ready) {
    throw new Error('worker not ready');
  }
};
/**
 * Dispatches a reply from the proxy worker to the promise of the request it answers.
 *
 * 'init-wasm' and 'init-ort' replies use the single-slot initialization callbacks; every other known reply
 * type settles the oldest queued request of that type. A reply carrying `err` rejects with it; otherwise the
 * promise resolves ('create' and 'run' resolve with `out`). An 'init-wasm' reply also updates the
 * initialization flags. Unrecognized reply types are ignored.
 */
const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
  switch (ev.data.type) {
    case 'init-wasm': {
      initializing = false;
      const failure = ev.data.err;
      if (failure) {
        aborted = true;
        initWasmCallbacks[1](failure);
      } else {
        initialized = true;
        initWasmCallbacks[0]();
      }
      break;
    }
    case 'init-ort': {
      const failure = ev.data.err;
      if (failure) {
        initOrtCallbacks[1](failure);
      } else {
        initOrtCallbacks[0]();
      }
      break;
    }
    case 'create': {
      const pending = createSessionCallbacks.shift()!;
      const failure = ev.data.err;
      if (failure) {
        pending[1](failure);
      } else {
        pending[0](ev.data.out!);
      }
      break;
    }
    case 'release': {
      const pending = releaseSessionCallbacks.shift()!;
      const failure = ev.data.err;
      if (failure) {
        pending[1](failure);
      } else {
        pending[0]();
      }
      break;
    }
    case 'run': {
      const pending = runCallbacks.shift()!;
      const failure = ev.data.err;
      if (failure) {
        pending[1](failure);
      } else {
        pending[0](ev.data.out!);
      }
      break;
    }
    case 'end-profiling': {
      const pending = endProfilingCallbacks.shift()!;
      const failure = ev.data.err;
      if (failure) {
        pending[1](failure);
      } else {
        pending[0]();
      }
      break;
    }
    default:
      // Unrecognized reply types are ignored.
      break;
  }
};
// URL of the <script> element executing while this module is evaluated, or undefined when no `document` exists
// or there is no current script. This is computed once at module load, because `document.currentScript` only
// refers to the script that is currently being processed.
const scriptSrc = ((): string|undefined => {
  if (typeof document === 'undefined') {
    return undefined;
  }
  const currentScript = document?.currentScript as HTMLScriptElement;
  return currentScript?.src;
})();
/**
 * Initializes the WebAssembly backend.
 *
 * Without proxying this delegates to `initializeWebAssembly(env.wasm)`. With proxying (see `isProxy()`), any
 * previous proxy worker is terminated, a new worker is created, and it is sent an 'init-wasm' request carrying
 * `env.wasm`. The returned promise settles when the worker replies (see `onProxyWorkerMessage`).
 *
 * The promise also rejects if the worker cannot be created or posted to, or if the worker raises an error event
 * before replying. In those cases initialization is marked as failed, exactly like an error reply from the
 * worker, so later calls report "previous call ... failed" instead of "multiple calls".
 *
 * Proxy mode only: if `env.wasm.wasmPaths` is unset, it defaults to the directory of `scriptSrc`, unless that
 * is a `blob:` URL.
 *
 * @throws (as a rejected promise) when a proxy initialization is already in progress, or a previous one failed.
 */
export const initWasm = async(): Promise<void> => {
  if (isProxy()) {
    if (initialized) {
      return;
    }
    if (initializing) {
      throw new Error('multiple calls to \'initWasm()\' detected.');
    }
    if (aborted) {
      throw new Error('previous call to \'initWasm()\' failed.');
    }

    initializing = true;

    // overwrite wasm filepaths: default to the directory of the loading script (blob: URLs are skipped)
    if (env.wasm.wasmPaths === undefined && scriptSrc && scriptSrc.indexOf('blob:') !== 0) {
      env.wasm.wasmPaths = scriptSrc.substring(0, scriptSrc.lastIndexOf('/') + 1);
    }

    return new Promise<void>((resolve, reject) => {
      // Leave the state flags consistent on failure; otherwise `initializing` would stay true forever and every
      // retry would be misreported as a concurrent call.
      const fail = (reason: unknown): void => {
        initializing = false;
        aborted = true;
        reject(reason);
      };

      try {
        proxyWorker?.terminate();
        // eslint-disable-next-line @typescript-eslint/no-var-requires, @typescript-eslint/no-require-imports
        proxyWorker = require('worker-loader?inline=no-fallback!./proxy-worker/main').default() as Worker;
        proxyWorker.onmessage = onProxyWorkerMessage;
        // Without this handler, a worker that fails to load or throws before replying leaves the promise
        // pending forever. Errors raised after initialization has settled are not treated as init failures.
        proxyWorker.onerror = (ev: ErrorEvent): void => {
          if (initializing) {
            fail(ev);
          }
        };
        initWasmCallbacks = [resolve, reject];
        const message: OrtWasmMessage = {type: 'init-wasm', in : env.wasm};
        proxyWorker.postMessage(message);
      } catch (e) {
        fail(e);
      }
    });
  } else {
    return initializeWebAssembly(env.wasm);
  }
};
/**
 * Initializes the ORT runtime with the given thread count and logging level.
 *
 * In proxy mode both values are sent to the worker in an 'init-ort' request, and the promise settles with the
 * worker's reply. Otherwise `core.initOrt()` is called directly.
 *
 * @throws (as a rejected promise) 'worker not ready' when proxying and the worker is not initialized.
 */
export const initOrt = async(numThreads: number, loggingLevel: number): Promise<void> => {
  if (!isProxy()) {
    core.initOrt(numThreads, loggingLevel);
    return;
  }

  ensureWorker();
  return new Promise<void>((resolve, reject) => {
    // A single slot holds the callbacks: only one 'init-ort' request is tracked at a time.
    initOrtCallbacks = [resolve, reject];
    const request: OrtWasmMessage = {type: 'init-ort', in : {numThreads, loggingLevel}};
    proxyWorker!.postMessage(request);
  });
};
/**
 * Creates an inference session from in-memory model bytes.
 *
 * In proxy mode the model is posted to the worker with `model.buffer` in the transfer list (so that buffer is
 * detached in this context afterwards). The promise settles with the worker's 'create' reply. Otherwise the
 * session is created directly by `core.createSession()`.
 *
 * @param model - the model bytes.
 * @param options - optional session options, forwarded unchanged.
 * @returns metadata describing the created session.
 * @throws (as a rejected promise) 'worker not ready' when proxying and the worker is not initialized, or the
 *     error thrown by `postMessage` if the request cannot be posted.
 */
export const createSession =
    async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise<SerializableSessionMetadata> => {
  if (isProxy()) {
    ensureWorker();
    return new Promise<SerializableSessionMetadata>((resolve, reject) => {
      const message: OrtWasmMessage = {type: 'create', in : {model, options}};
      proxyWorker!.postMessage(message, [model.buffer]);
      // Enqueue only after the post succeeded. Replies are matched to callbacks in FIFO order, so a callback left
      // behind by a failed post would receive the reply meant for the next request. Replies arrive as separate
      // message events, so none can be dispatched before this executor returns.
      createSessionCallbacks.push([resolve, reject]);
    });
  } else {
    return core.createSession(model, options);
  }
};
/**
 * Releases the session identified by `sessionId`.
 *
 * In proxy mode a 'release' request is posted to the worker, and the promise settles with its reply. Otherwise
 * `core.releaseSession()` is called directly.
 *
 * @throws (as a rejected promise) 'worker not ready' when proxying and the worker is not initialized, or the
 *     error thrown by `postMessage` if the request cannot be posted.
 */
export const releaseSession = async(sessionId: number): Promise<void> => {
  if (isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = {type: 'release', in : sessionId};
      proxyWorker!.postMessage(message);
      // Enqueue only after a successful post, so a failed post cannot desynchronize the FIFO of replies.
      releaseSessionCallbacks.push([resolve, reject]);
    });
  } else {
    core.releaseSession(sessionId);
  }
};
/**
 * Runs the session `sessionId` on the given inputs.
 *
 * In proxy mode a 'run' request is posted to the worker. The buffers returned by
 * `core.extractTransferableBuffers(inputs)` are placed in the transfer list (detached in this context
 * afterwards). The promise resolves with the worker's output tensors. Otherwise `core.run()` is called directly.
 *
 * @param sessionId - the session to run.
 * @param inputIndices - indices of the supplied inputs, parallel to `inputs`.
 * @param inputs - the input tensors.
 * @param outputIndices - indices of the requested outputs.
 * @param options - run options, forwarded unchanged.
 * @returns the output tensors.
 * @throws (as a rejected promise) 'worker not ready' when proxying and the worker is not initialized, or the
 *     error thrown while building the transfer list or posting the request.
 */
export const run = async(
    sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[],
    options: InferenceSession.RunOptions): Promise<SerializableTensor[]> => {
  if (isProxy()) {
    ensureWorker();
    return new Promise<SerializableTensor[]>((resolve, reject) => {
      const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}};
      proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs));
      // Enqueue only after the post succeeded. Replies are matched in FIFO order, and a stale callback from a failed
      // post would steal the next run's result. Replies arrive as later message events, never synchronously.
      runCallbacks.push([resolve, reject]);
    });
  } else {
    return core.run(sessionId, inputIndices, inputs, outputIndices, options);
  }
};
/**
 * Ends profiling for the session identified by `sessionId`.
 *
 * In proxy mode an 'end-profiling' request is posted to the worker, and the promise settles with its reply.
 * Otherwise `core.endProfiling()` is called directly.
 *
 * @throws (as a rejected promise) 'worker not ready' when proxying and the worker is not initialized, or the
 *     error thrown by `postMessage` if the request cannot be posted.
 */
export const endProfiling = async(sessionId: number): Promise<void> => {
  if (isProxy()) {
    ensureWorker();
    return new Promise<void>((resolve, reject) => {
      const message: OrtWasmMessage = {type: 'end-profiling', in : sessionId};
      proxyWorker!.postMessage(message);
      // Enqueue only after a successful post, so a failed post cannot desynchronize the FIFO of replies.
      endProfilingCallbacks.push([resolve, reject]);
    });
  } else {
    core.endProfiling(sessionId);
  }
};