onnxruntime/js/web/lib/wasm/session-options.ts
Wanming Lin 00b1e79e04
Support WebNN EP (#15698)
**Description**: 

This PR intends to enable WebNN EP in ONNX Runtime Web. It translates
the ONNX nodes by [WebNN
API](https://webmachinelearning.github.io/webnn/), which is implemented
in C++ and uses Emscripten [Embind
API](https://emscripten.org/docs/porting/connecting_cpp_and_javascript/embind.html#).
Temporarily using preferred layout **NHWC** for WebNN graph partitions
since the restriction in WebNN XNNPack backend implementation and the
ongoing
[discussion](https://github.com/webmachinelearning/webnn/issues/324) in
WebNN spec that whether WebNN should support both 'NHWC' and 'NCHW'
layouts. No WebNN native EP, only for Web.

**Motivation and Context**:
Allow ONNXRuntime Web developers to access WebNN API to benefit from
hardware acceleration.

**WebNN API Implementation Status in Chromium**:
- Tracked in Chromium issue:
[#1273291](https://bugs.chromium.org/p/chromium/issues/detail?id=1273291)
- **CPU device**: based on XNNPack backend, and had been available on
Chrome Canary M112 behind "#enable-experimental-web-platform-features"
flag for Windows and Linux platforms. Further implementation for more
ops is ongoing.
- **GPU device**: based on DML, implementation is ongoing.

**Open**:
- GitHub CI: WebNN currently is only available on Chrome Canary/Dev with
XNNPack backend for Linux and Windows. This is an open to reviewers to
help identify which GitHub CI should involved the WebNN EP and guide me
to enable it. Thanks!
2023-05-08 21:25:10 -07:00

166 lines
6.3 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from 'onnxruntime-common';
import {iterateExtraOptions} from './options-utils';
import {allocWasmString} from './string-utils';
import {getInstance} from './wasm-factory';
const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => {
switch (graphOptimizationLevel) {
case 'disabled':
return 0;
case 'basic':
return 1;
case 'extended':
return 2;
case 'all':
return 99;
default:
throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
}
};
const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => {
switch (executionMode) {
case 'sequential':
return 0;
case 'parallel':
return 1;
default:
throw new Error(`unsupported execution mode: ${executionMode}`);
}
};
const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => {
if (!options.extra) {
options.extra = {};
}
if (!options.extra.session) {
options.extra.session = {};
}
const session = options.extra.session as Record<string, string>;
if (!session.use_ort_model_bytes_directly) {
// eslint-disable-next-line camelcase
session.use_ort_model_bytes_directly = '1';
}
// if using JSEP with WebGPU, always disable memory pattern
if (options.executionProviders &&
options.executionProviders.some(ep => (typeof ep === 'string' ? ep : ep.name) === 'webgpu')) {
options.enableMemPattern = false;
}
};
const setExecutionProviders =
(sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[],
allocs: number[]): void => {
for (const ep of executionProviders) {
let epName = typeof ep === 'string' ? ep : ep.name;
// check EP name
switch (epName) {
case 'xnnpack':
epName = 'XNNPACK';
break;
case 'webnn':
epName = 'WEBNN';
if (typeof ep !== 'string') {
const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
if (webnnOptions?.deviceType) {
const keyDataOffset = allocWasmString('deviceType', allocs);
const valueDataOffset = allocWasmString(webnnOptions.deviceType, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
throw new Error(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}`);
}
}
if (webnnOptions?.powerPreference) {
const keyDataOffset = allocWasmString('powerPreference', allocs);
const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
0) {
throw new Error(
`Can't set a session config entry: 'powerPreference' - ${webnnOptions.powerPreference}`);
}
}
}
break;
case 'webgpu':
epName = 'JS';
break;
case 'wasm':
case 'cpu':
continue;
default:
throw new Error(`not supported EP: ${epName}`);
}
const epNameDataOffset = allocWasmString(epName, allocs);
if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) {
throw new Error(`Can't append execution provider: ${epName}`);
}
}
};
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
const wasm = getInstance();
let sessionOptionsHandle = 0;
const allocs: number[] = [];
const sessionOptions: InferenceSession.SessionOptions = options || {};
appendDefaultOptions(sessionOptions);
try {
const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel ?? 'all');
const executionMode = getExecutionMode(sessionOptions.executionMode ?? 'sequential');
const logIdDataOffset =
typeof sessionOptions.logId === 'string' ? allocWasmString(sessionOptions.logId, allocs) : 0;
const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2; // Default to 2 - warning
if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) {
throw new Error(`log serverity level is not valid: ${logSeverityLevel}`);
}
const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0; // Default to 0 - verbose
if (!Number.isInteger(logVerbosityLevel) || logVerbosityLevel < 0 || logVerbosityLevel > 4) {
throw new Error(`log verbosity level is not valid: ${logVerbosityLevel}`);
}
const optimizedModelFilePathOffset = typeof sessionOptions.optimizedModelFilePath === 'string' ?
allocWasmString(sessionOptions.optimizedModelFilePath, allocs) :
0;
sessionOptionsHandle = wasm._OrtCreateSessionOptions(
graphOptimizationLevel, !!sessionOptions.enableCpuMemArena, !!sessionOptions.enableMemPattern, executionMode,
!!sessionOptions.enableProfiling, 0, logIdDataOffset, logSeverityLevel, logVerbosityLevel,
optimizedModelFilePathOffset);
if (sessionOptionsHandle === 0) {
throw new Error('Can\'t create session options');
}
if (sessionOptions.executionProviders) {
setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
}
if (sessionOptions.extra !== undefined) {
iterateExtraOptions(sessionOptions.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
const keyDataOffset = allocWasmString(key, allocs);
const valueDataOffset = allocWasmString(value, allocs);
if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
throw new Error(`Can't set a session config entry: ${key} - ${value}`);
}
});
}
return [sessionOptionsHandle, allocs];
} catch (e) {
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(wasm._free);
throw e;
}
};