[js/webgpu] fix jsepOnRunEnd (#17300)

### Description
fix jsepOnRunEnd: jsepOnRunEnd() need to be run after runPromise is
resolved.
This commit is contained in:
Yulong Wang 2023-08-26 00:30:28 -07:00 committed by GitHub
parent 808215366d
commit ddcd46174e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 18 deletions

View file

@ -78,7 +78,7 @@ export interface OrtWasmModule extends EmscriptenModule {
_JsepGetNodeName(kernel: number): number;
jsepOnRunStart?(sessionId: number): void;
jsepOnRunEnd?(sessionId: number): void;
jsepOnRunEnd?(sessionId: number): Promise<void>;
jsepRunPromise?: Promise<number>;
// #endregion
}

View file

@ -258,6 +258,7 @@ export const run = async(
wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
}
// jsepOnRunStart is only available when JSEP is enabled.
wasm.jsepOnRunStart?.(sessionId);
// support RunOptions
@ -265,13 +266,28 @@ export const run = async(
sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset, runOptionsHandle);
wasm.jsepOnRunEnd?.(sessionId);
const runPromise = wasm.jsepRunPromise;
if (runPromise && typeof runPromise.then !== 'undefined') {
if (runPromise) {
// jsepRunPromise is a Promise object. It is only available when JSEP is enabled.
//
// OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to
// work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than
// the async functions are finished.
//
// To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of
// OrtRun(). If the promise exists, we need to await for the promise to be resolved.
errorCode = await runPromise;
}
const jsepOnRunEnd = wasm.jsepOnRunEnd;
if (jsepOnRunEnd) {
// jsepOnRunEnd is only available when JSEP is enabled.
//
// it returns a promise, which is resolved or rejected when the following async functions are finished:
// - collecting GPU validation errors.
await jsepOnRunEnd(sessionId);
}
const output: SerializableTensor[] = [];
if (errorCode !== 0) {

View file

@ -37,20 +37,17 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
const errorPromises = Module.jsepSessionState.errors;
Module.jsepSessionState = null;
if (errorPromises.length > 0) {
const runPromise = Module['jsepRunPromise'];
Module['jsepRunPromise'] = new Promise((resolve, reject) => {
Promise.all(errorPromises).then(errors => {
errors = errors.filter(e => e);
if (errors.length > 0) {
reject(new Error(errors.join('\n')));
} else {
resolve(runPromise);
}
}, reason => {
reject(reason);
});
return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => {
Promise.all(errorPromises).then(errors => {
errors = errors.filter(e => e);
if (errors.length > 0) {
reject(new Error(errors.join('\n')));
} else {
resolve();
}
}, reason => {
reject(reason);
});
}
});
};
};