[WebNN EP] Support WebNN async API with Asyncify (#19145)

This commit is contained in:
Wanming Lin 2024-01-25 07:37:35 +08:00 committed by GitHub
parent c456f19dba
commit 7252c6e747
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 30 additions and 49 deletions

View file

@ -21,10 +21,6 @@ interface BuildDefinitions {
/**
* defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WEBNN: boolean;
/**
* defines whether to disable the whole WebAssembly backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
* defines whether to disable proxy feature in WebAssembly backend in the build.

View file

@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
if (!BUILD_DEFS.DISABLE_WEBNN) {
registerBackend('webnn', wasmBackend, 9);
}
}
Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});

View file

@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
_OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;

View file

@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise<void> => {
* @param epName
*/
export const initEp = async(env: Env, epName: string): Promise<void> => {
if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') {
if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) {
// perform WebGPU availability check
if (typeof navigator === 'undefined' || !navigator.gpu) {
throw new Error('WebGPU is not supported in current environment');
@ -228,7 +228,7 @@ export const createSession = async(
await Promise.all(loadingPromises);
}
sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
if (sessionHandle === 0) {
checkLastError('Can\'t create a session.');
}

View file

@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // <ORT_ROOT>/js/
const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WEBGL': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'false',
'BUILD_DEFS.DISABLE_WEBNN': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
@ -364,7 +363,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
@ -397,7 +395,7 @@ async function main() {
// ort.webgpu[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.webgpu',
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'},
define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'},
});
// ort.wasm[.min].js
await addAllWebBuildTasks({
@ -411,7 +409,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WASM': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
// ort.wasm-core[.min].js
@ -421,7 +418,6 @@ async function main() {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
@ -434,7 +430,6 @@ async function main() {
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
'BUILD_DEFS.DISABLE_WEBNN': 'true',
},
});
}

View file

@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs
const globalEnvFlags = parseGlobalEnvFlags(args);
if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) {
throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.');
}
// Options:
// --log-verbose=<...>
// --log-info=<...>

View file

@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
"The input of graph has unsupported type, name: ",
name, " type: ", tensor.tensor_info.data_type);
}
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers.
// Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer.
// As Wasm ArrayBuffer is not detachable.
wnn_inputs_[name].call<void>("set", view);
#else
wnn_inputs_.set(name, view);
#endif
}
#ifdef ENABLE_WEBASSEMBLY_THREADS
// This vector uses for recording output buffers from WebNN graph compution when WebAssembly
// multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView,
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory
// address is different from the non-shared one, additional memory copy is required here.
InlinedHashMap<std::string, emscripten::val> output_views;
#endif
for (const auto& output : outputs) {
const std::string& name = output.first;
const struct OnnxTensorData tensor = output.second;
@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
name, " type: ", tensor.tensor_info.data_type);
}
#ifdef ENABLE_WEBASSEMBLY_THREADS
output_views.insert({name, view});
#else
wnn_outputs_.set(name, view);
#endif
}
wnn_context_.call<emscripten::val>("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_);
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer.
emscripten::val results = wnn_context_.call<emscripten::val>(
"compute", wnn_graph_, wnn_inputs_, wnn_outputs_)
.await();
// Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer.
for (const auto& output : outputs) {
const std::string& name = output.first;
emscripten::val view = output_views.at(name);
view.call<void>("set", wnn_outputs_[name]);
view.call<void>("set", results["outputs"][name]);
}
#endif
// WebNN compute() method would return the input and output buffers via the promise
// resolution. Reuse the buffers to avoid additional allocation.
wnn_inputs_ = results["inputs"];
wnn_outputs_ = results["outputs"];
return Status::OK();
}

View file

@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
for (auto& name : output_names_) {
named_operands.set(name, wnn_operands_.at(name));
}
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("buildSync", named_operands);
emscripten::val wnn_graph = wnn_builder_.call<emscripten::val>("build", named_operands).await();
if (!wnn_graph.as<bool>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph.");
}
@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
model->SetOutputs(std::move(output_names_));
model->SetScalarOutputs(std::move(scalar_outputs_));
model->SetInputOutputInfo(std::move(input_output_info_));
#ifdef ENABLE_WEBASSEMBLY_THREADS
// Pre-allocate the input and output tensors for the WebNN graph
// when WebAssembly multi-threads is enabled since WebNN API only
// accepts non-shared ArrayBufferView.
// https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews
// Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
// for inputs and outputs because they will be transferred after compute() done.
// https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution
model->AllocateInputOutputBuffers();
#endif
return Status::OK();
}

View file

@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
if (webnn_power_flags.compare("default") != 0) {
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
}
wnn_context_ = ml.call<emscripten::val>("createContextSync", context_options);
wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
if (!wnn_context_.as<bool>()) {
ORT_THROW("Failed to create WebNN context.");
}

View file

@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
};
// replace the original functions with asyncified versions
Module['_OrtCreateSession'] = jsepWrapAsync(
Module['_OrtCreateSession'],
() => Module['_OrtCreateSession'],
v => Module['_OrtCreateSession'] = v);
Module['_OrtRun'] = runAsync(jsepWrapAsync(
Module['_OrtRun'],
() => Module['_OrtRun'],