diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index d04578ca69..06fcbf6344 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -78,7 +78,7 @@ export interface OrtWasmModule extends EmscriptenModule { _JsepGetNodeName(kernel: number): number; jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): void; + jsepOnRunEnd?(sessionId: number): Promise; jsepRunPromise?: Promise; // #endregion } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 9dc55b0d12..fcca82ab2a 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -258,6 +258,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } + // jsepOnRunStart is only available when JSEP is enabled. wasm.jsepOnRunStart?.(sessionId); // support RunOptions @@ -265,13 +266,28 @@ export const run = async( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); - wasm.jsepOnRunEnd?.(sessionId); - const runPromise = wasm.jsepRunPromise; - if (runPromise && typeof runPromise.then !== 'undefined') { + if (runPromise) { + // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. + // + // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to + // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than + // the async functions are finished. + // + // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of + // OrtRun(). If the promise exists, we need to await for the promise to be resolved. errorCode = await runPromise; } + const jsepOnRunEnd = wasm.jsepOnRunEnd; + if (jsepOnRunEnd) { + // jsepOnRunEnd is only available when JSEP is enabled. + // + // it returns a promise, which is resolved or rejected when the following async functions are finished: + // - collecting GPU validation errors. + await jsepOnRunEnd(sessionId); + } + const output: SerializableTensor[] = []; if (errorCode !== 0) { diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index c7bc0e39fc..15d393f4ce 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -37,20 +37,17 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea const errorPromises = Module.jsepSessionState.errors; Module.jsepSessionState = null; - if (errorPromises.length > 0) { - const runPromise = Module['jsepRunPromise']; - Module['jsepRunPromise'] = new Promise((resolve, reject) => { - Promise.all(errorPromises).then(errors => { - errors = errors.filter(e => e); - if (errors.length > 0) { - reject(new Error(errors.join('\n'))); - } else { - resolve(runPromise); - } - }, reason => { - reject(reason); - }); + return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => { + Promise.all(errorPromises).then(errors => { + errors = errors.filter(e => e); + if (errors.length > 0) { + reject(new Error(errors.join('\n'))); + } else { + resolve(); + } + }, reason => { + reject(reason); }); - } + }); }; };