From 82786baed1f0087e7adb3d4fc19bf142cfb675bd Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 3 Oct 2022 10:38:45 -0700 Subject: [PATCH] [js/web] add 'xnnpack' to EP list (#12723) **Description**: This PR adds support for "XNNPACK EP" in ORTWeb and changes the behavior of how ORTWeb deals with "backends", or "EPs" in API. **Background**: Term "backend" is introduced in ONNX.js to representing a TypeScript type which implements a "backend" interface, which is a similar but different concept to ORT's EP (execution provider). There was 3 backends in ONNX.js: "cpu", "wasm" and "webgl". When ORT Web is launched, the concept is derived to help users to integrate smoothly. Technically, when "wasm" backend is used, users need to also specify "EP" in the session options. Considering it may get complicated and confused for users to figure out the difference between "backend" and "EP", the JS API hide the "backend" concept and made a mapping between names, backends and EPs: "webgl" (Name) <==> "onnxjsBackend" (Backend) "wasm" (Name) <==> "wasmBackend" (Backend) <==> "CPU" (EP) **Details**: The following changes are applied in this PR: 1. allow multi-registration for backends using the same name. This is for use scenarios where both "onnxruntime-node" and "onnxruntime-web" are consumed in a Node.js App ( so "cpu" will be registered twice in this scenario. ) 2. re-assign priority values to backends. I give 100 as base to "cpu" for node and react_native, and 10 as base to "cpu" in web. 3. add "cpu", "xnnpack" as new names of backends. 4. update onnxruntime wasm exported functions to support EP registration. 5. update implementations in ort web to handle execution providers in session options. 6. add '--use_xnnpack' as default build flag for ort-web --- js/common/lib/backend-impl.ts | 14 +++++++-- js/common/lib/inference-session.ts | 7 +++-- js/node/lib/index.ts | 2 +- js/web/lib/index.ts | 6 ++-- js/web/lib/wasm/binding/ort-wasm.d.ts | 1 + js/web/lib/wasm/session-options.ts | 29 +++++++++++++++++++ js/web/script/test-runner-cli-args.ts | 5 ++-- js/web/test/test-runner.ts | 2 +- onnxruntime/wasm/api.cc | 7 +++-- onnxruntime/wasm/api.h | 7 +++++ .../azure-pipelines/templates/win-wasm-ci.yml | 2 +- 11 files changed, 67 insertions(+), 15 deletions(-) diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 91c92247d2..49e363e398 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -30,13 +30,21 @@ export const registerBackend = (name: string, backend: Backend, priority: number const currentBackend = backends[name]; if (currentBackend === undefined) { backends[name] = {backend, priority}; - } else if (currentBackend.backend === backend) { + } else if (currentBackend.priority > priority) { + // same name is already registered with a higher priority. skip registeration. return; - } else { - throw new Error(`backend "${name}" is already registered`); + } else if (currentBackend.priority === priority) { + if (currentBackend.backend !== backend) { + throw new Error(`cannot register backend "${name}" using priority ${priority}`); + } } if (priority >= 0) { + const i = backendsSortedByPriority.indexOf(name); + if (i !== -1) { + backendsSortedByPriority.splice(i, 1); + } + for (let i = 0; i < backendsSortedByPriority.length; i++) { if (backends[backendsSortedByPriority[i]].priority <= priority) { backendsSortedByPriority.splice(i, 0, name); diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 06864f6a2a..1f2f855a3e 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -157,13 +157,14 @@ export declare namespace InferenceSession { // Currently, we have the following backends to support execution providers: // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'wasm'. + // Backend WebAssembly: supports 'cpu', 'wasm' and 'xnnpack'. // Backend ONNX.js: supports 'webgl'. interface ExecutionProviderOptionMap { cpu: CpuExecutionProviderOption; cuda: CudaExecutionProviderOption; wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; + xnnpack: XnnpackExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -183,12 +184,14 @@ export declare namespace InferenceSession { } export interface WebAssemblyExecutionProviderOption extends ExecutionProviderOption { readonly name: 'wasm'; - // TODO: add flags } export interface WebGLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgl'; // TODO: add flags } + export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'xnnpack'; + } // #endregion // #endregion diff --git a/js/node/lib/index.ts b/js/node/lib/index.ts index 957b22aab0..e455f69a8f 100644 --- a/js/node/lib/index.ts +++ b/js/node/lib/index.ts @@ -5,4 +5,4 @@ export * from 'onnxruntime-common'; import {registerBackend} from 'onnxruntime-common'; import {onnxruntimeBackend} from './backend'; -registerBackend('cpu', onnxruntimeBackend, 1); +registerBackend('cpu', onnxruntimeBackend, 100); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index fea2cd17e8..0e4a3f6d57 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -11,9 +11,11 @@ import {registerBackend} from 'onnxruntime-common'; if (!BUILD_DEFS.DISABLE_WEBGL) { const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend; - registerBackend('webgl', onnxjsBackend, -1); + registerBackend('webgl', onnxjsBackend, -10); } if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = require('./backend-wasm').wasmBackend; - registerBackend('wasm', wasmBackend, 0); + registerBackend('cpu', wasmBackend, 10); + registerBackend('wasm', wasmBackend, 10); + registerBackend('xnnpack', wasmBackend, 9); } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 565e7627f4..fd82a83bd7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -37,6 +37,7 @@ export interface OrtWasmModule extends EmscriptenModule { graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number, logVerbosityLevel: number): number; + _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 5b32fd5636..6d4d8eeb34 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -47,6 +47,31 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } }; +const setExecutionProviders = + (sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], + allocs: number[]): void => { + for (const ep of executionProviders) { + let epName = typeof ep === 'string' ? ep : ep.name; + + // check EP name + switch (epName) { + case 'xnnpack': + epName = 'XNNPACK'; + break; + case 'wasm': + case 'cpu': + continue; + default: + throw new Error(`not supported EP: ${epName}`); + } + + const epNameDataOffset = allocWasmString(epName, allocs); + if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { + throw new Error(`Can't append execution provider: ${epName}`); + } + } + }; + export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { const wasm = getInstance(); let sessionOptionsHandle = 0; @@ -105,6 +130,10 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n throw new Error('Can\'t create session options'); } + if (options?.executionProviders) { + setExecutionProviders(sessionOptionsHandle, options.executionProviders, allocs); + } + if (options?.extra !== undefined) { iterateExtraOptions(options.extra, '', new WeakSet>(), (key, value) => { const keyDataOffset = allocWasmString(key, allocs); diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index c390b330e8..c008933bc4 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -35,6 +35,7 @@ Options: Backends can be one or more of the following, splitted by comma: webgl wasm + xnnpack -e=<...>, --env=<...> Specify the environment to run the test. Should be one of the following: chrome (default) edge (Windows only) @@ -97,7 +98,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'; + type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'|'xnnpack'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -333,7 +334,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'wasm']; + const browserBackends = ['webgl', 'wasm', 'xnnpack']; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 9f587907ef..4bdea197bb 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -271,7 +271,7 @@ export class TensorResultValidator { this.absoluteThreshold = WEBGL_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGL_THRESHOLD_RELATIVE_ERROR; } - } else if (backend === 'wasm') { + } else if (backend === 'wasm' || backend === 'xnnpack') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 2e52460b34..ea39b1c3f2 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -112,13 +112,14 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, // Enable ORT CustomOps in onnxruntime-extensions RETURN_NULLPTR_IF_ERROR(EnableOrtCustomOps, session_options); #endif -#if defined(USE_XNNPACK) - RETURN_NULLPTR_IF_ERROR(SessionOptionsAppendExecutionProvider, session_options, "XNNPACK", nullptr, nullptr, 0); -#endif return session_options; } +int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const char* name) { + return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); +} + int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, const char* config_key, const char* config_value) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 2c6ca6b9eb..d3435f2958 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -58,6 +58,13 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t size_t log_severity_level, size_t log_verbosity_level); +/** + * append an execution provider for a session. + * @param name the name of the execution provider + */ +int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, + const char* name); + /** * store configurations for a session. * @param session_options a handle to session options created by OrtCreateSessionOptions diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 62e952b821..26e4d84898 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -29,7 +29,7 @@ jobs: variables: EnvSetupScript: setup_env.bat buildArch: x64 - CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wasm --emsdk_version releases-upstream-4c3772879a04140298c3abde90962d5567b5e2fc-64bit ${{ parameters.ExtraBuildArgs }}' + CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wasm --use_xnnpack --emsdk_version releases-upstream-4c3772879a04140298c3abde90962d5567b5e2fc-64bit ${{ parameters.ExtraBuildArgs }}' runCodesignValidationInjection: false timeoutInMinutes: ${{ parameters.TimeoutInMinutes }} workspace: