mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[WebNN EP] Support WebNN async API with Asyncify (#19145)
This commit is contained in:
parent
c456f19dba
commit
7252c6e747
10 changed files with 30 additions and 49 deletions
4
js/web/lib/build-def.d.ts
vendored
4
js/web/lib/build-def.d.ts
vendored
|
|
@ -21,10 +21,6 @@ interface BuildDefinitions {
|
|||
/**
|
||||
* defines whether to disable the whole WebNN backend in the build.
|
||||
*/
|
||||
readonly DISABLE_WEBNN: boolean;
|
||||
/**
|
||||
* defines whether to disable the whole WebAssembly backend in the build.
|
||||
*/
|
||||
readonly DISABLE_WASM: boolean;
|
||||
/**
|
||||
* defines whether to disable proxy feature in WebAssembly backend in the build.
|
||||
|
|
|
|||
|
|
@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
|
|||
require('./backend-wasm-training').wasmBackend;
|
||||
if (!BUILD_DEFS.DISABLE_WEBGPU) {
|
||||
registerBackend('webgpu', wasmBackend, 5);
|
||||
registerBackend('webnn', wasmBackend, 5);
|
||||
}
|
||||
registerBackend('cpu', wasmBackend, 10);
|
||||
registerBackend('wasm', wasmBackend, 10);
|
||||
if (!BUILD_DEFS.DISABLE_WEBNN) {
|
||||
registerBackend('webnn', wasmBackend, 9);
|
||||
}
|
||||
}
|
||||
|
||||
Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
|
||||
|
|
|
|||
2
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
2
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
|
|
@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {
|
|||
|
||||
_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
|
||||
|
||||
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
|
||||
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
|
||||
_OrtReleaseSession(sessionHandle: number): void;
|
||||
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
|
||||
_OrtGetInputName(sessionHandle: number, index: number): number;
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
|
|||
* @param epName
|
||||
*/
|
||||
export const initEp = async(env: Env, epName: string): Promise<void> => {
|
||||
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
|
||||
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
|
||||
// perform WebGPU availability check
|
||||
if (typeof navigator === 'undefined' || !navigator.gpu) {
|
||||
throw new Error('WebGPU is not supported in current environment');
|
||||
|
|
@ -228,7 +228,7 @@ export const createSession = async(
|
|||
await Promise.all(loadingPromises);
|
||||
}
|
||||
|
||||
sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
|
||||
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
|
||||
if (sessionHandle === 0) {
|
||||
checkLastError('Can\'t create a session.');
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
|
|||
const DEFAULT_DEFINE = {
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'false',
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
|
||||
'BUILD_DEFS.DISABLE_WEBNN': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
|
||||
|
|
@ -364,7 +363,6 @@ async function main() {
|
|||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBNN': 'true',
|
||||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
|
||||
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
|
||||
},
|
||||
|
|
@ -397,7 +395,7 @@ async function main() {
|
|||
// ort.webgpu[.min].js
|
||||
await addAllWebBuildTasks({
|
||||
outputBundleName: 'ort.webgpu',
|
||||
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
|
||||
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
|
||||
});
|
||||
// ort.wasm[.min].js
|
||||
await addAllWebBuildTasks({
|
||||
|
|
@ -411,7 +409,6 @@ async function main() {
|
|||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
'BUILD_DEFS.DISABLE_WASM': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBNN': 'true',
|
||||
},
|
||||
});
|
||||
// ort.wasm-core[.min].js
|
||||
|
|
@ -421,7 +418,6 @@ async function main() {
|
|||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBNN': 'true',
|
||||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
|
||||
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
|
||||
},
|
||||
|
|
@ -434,7 +430,6 @@ async function main() {
|
|||
'BUILD_DEFS.DISABLE_TRAINING': 'false',
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBNN': 'true',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
|
|||
|
||||
const globalEnvFlags = parseGlobalEnvFlags(args);
|
||||
|
||||
if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
|
||||
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
|
||||
}
|
||||
|
||||
// Options:
|
||||
// --log-verbose=<...>
|
||||
// --log-info=<...>
|
||||
|
|
|
|||
|
|
@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
|
|||
"The input of graph has unsupported type, name: ",
|
||||
name, " type: ", tensor.tensor_info.data_type);
|
||||
}
|
||||
#ifdef ENABLE_WEBASSEMBLY_THREADS
|
||||
// Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers.
|
||||
// Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer.
|
||||
// As Wasm ArrayBuffer is not detachable.
|
||||
wnn_inputs_[name].call<void>("set", view);
|
||||
#else
|
||||
wnn_inputs_.set(name, view);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_WEBASSEMBLY_THREADS
|
||||
// This vector uses for recording output buffers from WebNN graph compution when WebAssembly
|
||||
// multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView,
|
||||
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
|
||||
// and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory
|
||||
// address is different from the non-shared one, additional memory copy is required here.
|
||||
InlinedHashMap<std::string, emscripten::val> output_views;
|
||||
#endif
|
||||
|
||||
for (const auto& output : outputs) {
|
||||
const std::string& name = output.first;
|
||||
const struct OnnxTensorData tensor = output.second;
|
||||
|
|
@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
|
|||
name, " type: ", tensor.tensor_info.data_type);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_WEBASSEMBLY_THREADS
|
||||
output_views.insert({name, view});
|
||||
#else
|
||||
wnn_outputs_.set(name, view);
|
||||
#endif
|
||||
}
|
||||
wnn_context_.call<emscripten::val>("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_);
|
||||
#ifdef ENABLE_WEBASSEMBLY_THREADS
|
||||
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer.
|
||||
emscripten::val results = wnn_context_.call<emscripten::val>(
|
||||
"compute", wnn_graph_, wnn_inputs_, wnn_outputs_)
|
||||
.await();
|
||||
|
||||
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
|
||||
for (const auto& output : outputs) {
|
||||
const std::string& name = output.first;
|
||||
emscripten::val view = output_views.at(name);
|
||||
view.call<void>("set", wnn_outputs_[name]);
|
||||
view.call<void>("set", results["outputs"][name]);
|
||||
}
|
||||
#endif
|
||||
// WebNN compute() method would return the input and output buffers via the promise
|
||||
// resolution. Reuse the buffers to avoid additional allocation.
|
||||
wnn_inputs_ = results["inputs"];
|
||||
wnn_outputs_ = results["outputs"];
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
for (auto& name : output_names_) {
|
||||
named_operands.set(name, wnn_operands_.at(name));
|
||||
}
|
||||
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("buildSync", named_operands);
|
||||
|
||||
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
|
||||
if (!wnn_graph.as<bool>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
|
||||
}
|
||||
|
|
@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
|
|||
model->SetOutputs(std::move(output_names_));
|
||||
model->SetScalarOutputs(std::move(scalar_outputs_));
|
||||
model->SetInputOutputInfo(std::move(input_output_info_));
|
||||
#ifdef ENABLE_WEBASSEMBLY_THREADS
|
||||
// Pre-allocate the input and output tensors for the WebNN graph
|
||||
// when WebAssembly multi-threads is enabled since WebNN API only
|
||||
// accepts non-shared ArrayBufferView.
|
||||
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
|
||||
// Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
|
||||
// for inputs and outputs because they will be transferred after compute() done.
|
||||
// https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
|
||||
model->AllocateInputOutputBuffers();
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
|
|||
if (webnn_power_flags.compare("default") != 0) {
|
||||
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
|
||||
}
|
||||
wnn_context_ = ml.call<emscripten::val>("createContextSync", context_options);
|
||||
|
||||
wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
|
||||
if (!wnn_context_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN context.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
|
|||
};
|
||||
|
||||
// replace the original functions with asyncified versions
|
||||
Module['_OrtCreateSession'] = jsepWrapAsync(
|
||||
Module['_OrtCreateSession'],
|
||||
() => Module['_OrtCreateSession'],
|
||||
v => Module['_OrtCreateSession'] = v);
|
||||
Module['_OrtRun'] = runAsync(jsepWrapAsync(
|
||||
Module['_OrtRun'],
|
||||
() => Module['_OrtRun'],
|
||||
|
|
|
|||
Loading…
Reference in a new issue