[js/webgpu] support error pop and kernel name (#17260)

### Description
This PR contains changes to support error pop and kernel name.

- Add a function `JsepGetNodeName` that lets JS read a kernel's node name
from C++
- When in debug mode (`env.debug = true;`) or in profiling mode
(`env.webgpu.profilingMode = 'default';`), the kernel name is read from
ORT; otherwise the kernel pointer (a number) is used as the kernel name,
which saves calls from JS to C++.
- When in debug mode, WebGPU validation errors will be recorded and, if
any error occurs, `inferenceSession.run()` will fail (the Promise gets
rejected). Behavior when not in debug mode is unchanged. This is
because recording errors is not zero-overhead, and GPU validation
errors should occur consistently whether or not debug mode is enabled.
- Add `jsepOnRunStart()` and `jsepOnRunEnd()` hooks to:
   - allow implementation of the features mentioned above.
   - pass session ID to backend.
This commit is contained in:
Yulong Wang 2023-08-25 08:08:15 -07:00 committed by GitHub
parent da180b20fa
commit 79c4ed9a45
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 112 additions and 40 deletions

View file

@ -206,7 +206,7 @@ else()
set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']")
if (onnxruntime_USE_JSEP)
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput")
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
declare namespace JSEP {
export declare namespace JSEP {
type BackendType = unknown;
type AllocFunction = (size: number) => number;
type FreeFunction = (size: number) => number;
@ -9,7 +9,11 @@ declare namespace JSEP {
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction = (kernel: number, contextDataOffset: number) => number;
type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number;
/**
 * State shared between the run hooks and the kernel run callback for a single session run.
 *
 * An object of this shape is created when a run starts and is handed to every kernel invocation
 * (as the third argument of `RunFunction`) until the run ends.
 */
export interface SessionState {
/** ID of the session whose run is currently in progress. */
sessionId: number;
/**
 * Pending validation results collected while kernels execute. Each promise resolves to an error
 * message when a problem was recorded for a kernel, or to `null` when none was.
 */
errors: Array<Promise<string|null>>;
}
}
export interface OrtWasmModule extends EmscriptenModule {
@ -71,7 +75,10 @@ export interface OrtWasmModule extends EmscriptenModule {
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
_JsepOutput(context: number, index: number, data: number): number;
_JsepGetNodeName(kernel: number): number;
jsepOnRunStart?(sessionId: number): void;
jsepOnRunEnd?(sessionId: number): void;
jsepRunPromise?: Promise<number>;
// #endregion
}

View file

@ -82,9 +82,10 @@ export class WebGpuBackend {
}
/**
* a KernelID -> kernel info mapping. value is [ name, run function, [optional] preprocess_attribute_once function ]
* a KernelID -> kernel info mapping. value is
* [ op_type, name, run function, [optional] preprocess_attribute_once function ]
*/
kernels: Map<number, [string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
kernels: Map<number, [string, string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
commandEncoder: GPUCommandEncoder|null = null;
computePassEncoder: GPUComputePassEncoder|null = null;
@ -313,13 +314,13 @@ export class WebGpuBackend {
return this.gpuDataManager.release(ptr);
}
createKernel(name: string, kernelId: number, attribute: unknown): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(name);
createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void {
const op = WEBGPU_OP_RESOLVE_RULES.get(opType);
if (!op) {
throw new Error(`kernel not implemented: ${name}`);
throw new Error(`kernel not implemented: ${opType}`);
}
this.kernels.set(kernelId, [name, op[0], [op[1], attribute]]);
this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]);
}
releaseKernel(kernelId: number): void {
@ -335,14 +336,14 @@ export class WebGpuBackend {
this.kernels.delete(kernelId);
}
computeKernel(kernelId: number, context: ComputeContext): number {
computeKernel(kernelId: number, context: ComputeContext, errors: Array<Promise<string|null>>): number {
const kernel = this.kernels.get(kernelId);
if (!kernel) {
throw new Error(`kernel not created: ${kernelId}`);
}
const [name, kernelEntry, attributes] = kernel;
const [opType, nodeName, kernelEntry, attributes] = kernel;
if (this.currentKernelId !== null) {
throw new Error(`kernel "${name}" is not allowed to be called recursively`);
throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`);
}
this.currentKernelId = kernelId;
@ -352,16 +353,27 @@ export class WebGpuBackend {
attributes[0] = undefined;
}
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "${name}"...`);
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`);
const useErrorScope = this.env.debug;
this.temporaryData = [];
try {
if (useErrorScope) {
this.device.pushErrorScope('validation');
}
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`);
LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`);
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
errors.push(this.device.popErrorScope().then(
err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null));
}
for (const data of this.temporaryData) {
this.gpuDataManager.release(data.id);
}

View file

@ -3,7 +3,7 @@
import {Env} from 'onnxruntime-common';
import {OrtWasmModule} from '../binding/ort-wasm';
import {JSEP, OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';
import {WebGpuBackend} from './backend-webgpu';
@ -169,16 +169,22 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
},
// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(name, kernel, attribute),
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
`${kernel}`),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
// jsepRun
(kernel: number, contextDataOffset: number) => {
LOG_DEBUG('verbose', () => `[WebGPU] jsepRun: kernel=${kernel}, contextDataOffset=${contextDataOffset}`);
(kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context);
return backend.computeKernel(kernel, context, sessionState.errors);
});
}
};

View file

@ -77,7 +77,8 @@ export class ProgramManager {
this.backend.flush();
const kernelId = this.backend.currentKernelId!;
const kernelName = this.backend.kernels.get(kernelId)![0];
const kernelInfo = this.backend.kernels.get(kernelId)!;
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());

View file

@ -258,12 +258,15 @@ export const run = async(
wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
}
wasm.jsepOnRunStart?.(sessionId);
// support RunOptions
let errorCode = wasm._OrtRun(
sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
outputValuesOffset, runOptionsHandle);
// eslint-disable-next-line @typescript-eslint/naming-convention
wasm.jsepOnRunEnd?.(sessionId);
const runPromise = wasm.jsepRunPromise;
if (runPromise && typeof runPromise.then !== 'undefined') {
errorCode = await runPromise;

View file

@ -5,8 +5,8 @@
#include "core/framework/op_kernel.h"
const void* JsepOutput(void* context, int index, void* data) {
uint32_t* data_offset = reinterpret_cast<uint32_t*>(data);
const void* JsepOutput(void* context, int index, const void* data) {
const uint32_t* data_offset = reinterpret_cast<const uint32_t*>(data);
uint32_t dim = *data_offset++;
size_t dim_size = static_cast<size_t>(dim);
std::vector<int64_t> dims;
@ -24,3 +24,8 @@ const void* JsepOutput(void* context, int index, void* data) {
LOGF_DEFAULT(VERBOSE, "JsepOutput -- data=%zu", (size_t)(r));
return r;
}
// Returns the name of the graph node that the given kernel was created for, as a C string.
//
// `kernel` is an opaque pointer that is interpreted as an `onnxruntime::OpKernel`. The returned
// pointer comes from `c_str()` on the node's name, so it is only valid while that string is alive
// and unmodified (presumably for the lifetime of the node -- confirm against `Node::Name()`).
const void* JsepGetNodeName(const void* kernel) {
  const auto* op_kernel = reinterpret_cast<const onnxruntime::OpKernel*>(kernel);
  return op_kernel->Node().Name().c_str();
}

View file

@ -7,8 +7,7 @@
#include <stddef.h>
// TODO: Move to api.h
extern "C" {
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, void* data);
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, const void* data);
const void* EMSCRIPTEN_KEEPALIVE JsepGetNodeName(const void* context);
};

View file

@ -194,7 +194,9 @@ class JsKernel : public OpKernel {
return status;
}
int status_code = EM_ASM_INT({ return Module.jsepRun($0, $1); }, this, reinterpret_cast<int32_t>(p_serialized_kernel_context));
int status_code = EM_ASM_INT(
{ return Module.jsepRunKernel($0, $1, Module.jsepSessionState); },
this, reinterpret_cast<int32_t>(p_serialized_kernel_context));
LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
<< (size_t)(context->Output<Tensor>(0)->DataRaw()) << ".";

View file

@ -368,12 +368,9 @@ int OrtRun(OrtSession* session,
const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count,
const char** output_names, size_t output_count, ort_tensor_handle_t* outputs,
OrtRunOptions* run_options) {
#if defined(USE_JSEP)
EM_ASM({ Module["jsepRunPromise"] = new Promise(function(r) { Module.jsepRunPromiseResolve = r; }); });
#endif
auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
#if defined(USE_JSEP)
EM_ASM({ Module.jsepRunPromiseResolve($0); }, status_code);
EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code);
#endif
return status_code;
}

View file

@ -4,13 +4,53 @@
'use strict';
// init JSEP
Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) {
Module.jsepBackend = backend;
Module.jsepAlloc = alloc;
Module.jsepFree = free;
Module.jsepCopy = copy;
Module.jsepCopyAsync = copyAsync;
Module.jsepCreateKernel = createKernel;
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRun = run;
// Registers the JSEP backend and its callbacks on `Module`, and installs the per-run hooks
// `jsepOnRunStart` / `jsepOnRunEnd`.
//
// Parameters are stored verbatim as `Module.jsepBackend`, `Module.jsepAlloc`, `Module.jsepFree`,
// `Module.jsepCopy`, `Module.jsepCopyAsync`, `Module.jsepCreateKernel`, `Module.jsepReleaseKernel`
// and `Module.jsepRunKernel`.
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => {
  Module.jsepBackend = backend;
  Module.jsepAlloc = alloc;
  Module.jsepFree = free;
  Module.jsepCopy = copy;
  Module.jsepCopyAsync = copyAsync;
  Module.jsepCreateKernel = createKernel;
  Module.jsepReleaseKernel = releaseKernel;
  Module.jsepRunKernel = runKernel;

  // Called right before a session run starts.
  // Creates the promise that is resolved through `Module.jsepRunPromiseResolve` and the session
  // state object (`{sessionId, errors}`) that is passed to every kernel invocation.
  // Throws if a run is already in progress.
  Module['jsepOnRunStart'] = sessionId => {
    // Validate before touching any state, so that a rejected start leaves the run promise and
    // its resolver of the run in progress untouched.
    if (Module.jsepSessionState) {
      throw new Error('Session already started');
    }

    Module['jsepRunPromise'] = new Promise(r => {
      Module.jsepRunPromiseResolve = r;
    });
    Module.jsepSessionState = {
      sessionId,
      errors: []
    };
  };

  // Called right after a session run returns.
  // Clears the session state. If any error promises were collected during the run,
  // `Module.jsepRunPromise` is replaced by a promise that rejects with all collected error
  // messages (joined by newlines), or otherwise resolves to the result of the original run promise.
  // Throws if there is no active run or if the active run belongs to a different session.
  Module['jsepOnRunEnd'] = sessionId => {
    const sessionState = Module.jsepSessionState;
    // Guard against a missing state as well: calling this hook without a matching
    // `jsepOnRunStart` should produce the same descriptive error rather than a TypeError.
    if (!sessionState || sessionState.sessionId !== sessionId) {
      throw new Error('Session ID mismatch');
    }

    const errorPromises = sessionState.errors;
    Module.jsepSessionState = null;

    if (errorPromises.length > 0) {
      const runPromise = Module['jsepRunPromise'];
      // A rejection of any error promise propagates through the chain unchanged.
      Module['jsepRunPromise'] = Promise.all(errorPromises).then(errors => {
        // Entries are either an error message or null; keep only the messages.
        const messages = errors.filter(e => e);
        if (messages.length > 0) {
          throw new Error(messages.join('\n'));
        }
        return runPromise;
      });
    }
  };
};