diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 91c92247d2..49e363e398 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -30,13 +30,21 @@ export const registerBackend = (name: string, backend: Backend, priority: number const currentBackend = backends[name]; if (currentBackend === undefined) { backends[name] = {backend, priority}; - } else if (currentBackend.backend === backend) { + } else if (currentBackend.priority > priority) { + // same name is already registered with a higher priority. skip registration. return; - } else { - throw new Error(`backend "${name}" is already registered`); + } else if (currentBackend.priority === priority) { + if (currentBackend.backend !== backend) { + throw new Error(`cannot register backend "${name}" using priority ${priority}`); + } } if (priority >= 0) { + const i = backendsSortedByPriority.indexOf(name); + if (i !== -1) { + backendsSortedByPriority.splice(i, 1); + } + for (let i = 0; i < backendsSortedByPriority.length; i++) { if (backends[backendsSortedByPriority[i]].priority <= priority) { backendsSortedByPriority.splice(i, 0, name); diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 06864f6a2a..1f2f855a3e 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -157,13 +157,14 @@ export declare namespace InferenceSession { // Currently, we have the following backends to support execution providers: // Backend Node.js binding: supports 'cpu' and 'cuda'. - // Backend WebAssembly: supports 'wasm'. + // Backend WebAssembly: supports 'cpu', 'wasm' and 'xnnpack'. // Backend ONNX.js: supports 'webgl'. 
interface ExecutionProviderOptionMap { cpu: CpuExecutionProviderOption; cuda: CudaExecutionProviderOption; wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; + xnnpack: XnnpackExecutionProviderOption; } type ExecutionProviderName = keyof ExecutionProviderOptionMap; @@ -183,12 +184,14 @@ export declare namespace InferenceSession { } export interface WebAssemblyExecutionProviderOption extends ExecutionProviderOption { readonly name: 'wasm'; - // TODO: add flags } export interface WebGLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgl'; // TODO: add flags } + export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'xnnpack'; + } // #endregion // #endregion diff --git a/js/node/lib/index.ts b/js/node/lib/index.ts index 957b22aab0..e455f69a8f 100644 --- a/js/node/lib/index.ts +++ b/js/node/lib/index.ts @@ -5,4 +5,4 @@ export * from 'onnxruntime-common'; import {registerBackend} from 'onnxruntime-common'; import {onnxruntimeBackend} from './backend'; -registerBackend('cpu', onnxruntimeBackend, 1); +registerBackend('cpu', onnxruntimeBackend, 100); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index fea2cd17e8..0e4a3f6d57 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -11,9 +11,11 @@ import {registerBackend} from 'onnxruntime-common'; if (!BUILD_DEFS.DISABLE_WEBGL) { const onnxjsBackend = require('./backend-onnxjs').onnxjsBackend; - registerBackend('webgl', onnxjsBackend, -1); + registerBackend('webgl', onnxjsBackend, -10); } if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = require('./backend-wasm').wasmBackend; - registerBackend('wasm', wasmBackend, 0); + registerBackend('cpu', wasmBackend, 10); + registerBackend('wasm', wasmBackend, 10); + registerBackend('xnnpack', wasmBackend, 9); } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 565e7627f4..fd82a83bd7 100644 --- 
a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -37,6 +37,7 @@ export interface OrtWasmModule extends EmscriptenModule { graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number, logVerbosityLevel: number): number; + _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 5b32fd5636..6d4d8eeb34 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -47,6 +47,31 @@ const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => } }; +const setExecutionProviders = + (sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[], + allocs: number[]): void => { + for (const ep of executionProviders) { + let epName = typeof ep === 'string' ? 
ep : ep.name; + + // check EP name + switch (epName) { + case 'xnnpack': + epName = 'XNNPACK'; + break; + case 'wasm': + case 'cpu': + continue; + default: + throw new Error(`not supported EP: ${epName}`); + } + + const epNameDataOffset = allocWasmString(epName, allocs); + if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) { + throw new Error(`Can't append execution provider: ${epName}`); + } + } + }; + export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { const wasm = getInstance(); let sessionOptionsHandle = 0; @@ -105,6 +130,10 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n throw new Error('Can\'t create session options'); } + if (options?.executionProviders) { + setExecutionProviders(sessionOptionsHandle, options.executionProviders, allocs); + } + if (options?.extra !== undefined) { iterateExtraOptions(options.extra, '', new WeakSet>(), (key, value) => { const keyDataOffset = allocWasmString(key, allocs); diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index c390b330e8..c008933bc4 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -35,6 +35,7 @@ Options: Backends can be one or more of the following, splitted by comma: webgl wasm + xnnpack -e=<...>, --env=<...> Specify the environment to run the test. 
Should be one of the following: chrome (default) edge (Windows only) @@ -97,7 +98,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; - type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'; + type Backend = 'cpu'|'webgl'|'wasm'|'onnxruntime'|'xnnpack'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; } @@ -333,7 +334,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs } // Option: -b=<...>, --backend=<...> - const browserBackends = ['webgl', 'wasm']; + const browserBackends = ['webgl', 'wasm', 'xnnpack']; const nodejsBackends = ['cpu', 'wasm']; const backendArgs = args.backend || args.b; const backend = diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 9f587907ef..4bdea197bb 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -271,7 +271,7 @@ export class TensorResultValidator { this.absoluteThreshold = WEBGL_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WEBGL_THRESHOLD_RELATIVE_ERROR; } - } else if (backend === 'wasm') { + } else if (backend === 'wasm' || backend === 'xnnpack') { this.absoluteThreshold = WASM_THRESHOLD_ABSOLUTE_ERROR; this.relativeThreshold = WASM_THRESHOLD_RELATIVE_ERROR; } else if (backend === 'onnxruntime') { diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 2e52460b34..ea39b1c3f2 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -112,13 +112,14 @@ OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, // Enable ORT CustomOps in onnxruntime-extensions RETURN_NULLPTR_IF_ERROR(EnableOrtCustomOps, session_options); #endif -#if defined(USE_XNNPACK) - RETURN_NULLPTR_IF_ERROR(SessionOptionsAppendExecutionProvider, session_options, "XNNPACK", nullptr, nullptr, 0); -#endif return session_options; } +int OrtAppendExecutionProvider(ort_session_options_handle_t session_options, const 
char* name) { + return CHECK_STATUS(SessionOptionsAppendExecutionProvider, session_options, name, nullptr, nullptr, 0); +} + int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, const char* config_key, const char* config_value) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 2c6ca6b9eb..d3435f2958 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -58,6 +58,13 @@ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t size_t log_severity_level, size_t log_verbosity_level); +/** + * append an execution provider for a session. + * @param name the name of the execution provider + */ +int EMSCRIPTEN_KEEPALIVE OrtAppendExecutionProvider(ort_session_options_handle_t session_options, + const char* name); + /** * store configurations for a session. * @param session_options a handle to session options created by OrtCreateSessionOptions diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 62e952b821..26e4d84898 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -29,7 +29,7 @@ jobs: variables: EnvSetupScript: setup_env.bat buildArch: x64 - CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wasm --emsdk_version releases-upstream-4c3772879a04140298c3abde90962d5567b5e2fc-64bit ${{ parameters.ExtraBuildArgs }}' + CommonBuildArgs: '--parallel --config ${{ parameters.BuildConfig }} --skip_submodule_sync --cmake_generator "Visual Studio 16 2019" --build_wasm --use_xnnpack --emsdk_version releases-upstream-4c3772879a04140298c3abde90962d5567b5e2fc-64bit ${{ parameters.ExtraBuildArgs }}' runCodesignValidationInjection: false timeoutInMinutes: ${{ parameters.TimeoutInMinutes }} workspace: