mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[js/webgpu] support error pop and kernel name (#17260)
### Description This PR adds support for popping WebGPU error scopes and for kernel names. - Add a function `JsepGetNodeName` so that JS can read a kernel's node name from C++. - When in debug mode ( `env.debug = true;` ) or in profiling mode ( `env.webgpu.profilingMode = 'default';` ), the kernel name is read from ORT; otherwise the kernel pointer ( a number ) is used as the kernel name, to save calls from JS to C++. - When in debug mode, WebGPU validation errors are recorded, and if any error occurs, `inferenceSession.run()` fails (the Promise is rejected). Behavior when not in debug mode is unchanged. This is because recording errors is not zero-overhead, and GPU validation errors should occur consistently whether or not debug mode is enabled. - Add `jsepOnRunStart()` and `jsepOnRunEnd()` hooks to: - allow implementation of the features mentioned above. - pass the session ID to the backend.
This commit is contained in:
parent
da180b20fa
commit
79c4ed9a45
11 changed files with 112 additions and 40 deletions
|
|
@ -206,7 +206,7 @@ else()
|
|||
|
||||
set(EXPORTED_RUNTIME_METHODS "['stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8']")
|
||||
if (onnxruntime_USE_JSEP)
|
||||
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput")
|
||||
set(EXPORTED_FUNCTIONS "_malloc,_free,_JsepOutput,_JsepGetNodeName")
|
||||
else()
|
||||
set(EXPORTED_FUNCTIONS "_malloc,_free")
|
||||
endif()
|
||||
|
|
|
|||
11
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
11
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
declare namespace JSEP {
|
||||
export declare namespace JSEP {
|
||||
type BackendType = unknown;
|
||||
type AllocFunction = (size: number) => number;
|
||||
type FreeFunction = (size: number) => number;
|
||||
|
|
@ -9,7 +9,11 @@ declare namespace JSEP {
|
|||
type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise<void>;
|
||||
type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void;
|
||||
type ReleaseKernelFunction = (kernel: number) => void;
|
||||
type RunFunction = (kernel: number, contextDataOffset: number) => number;
|
||||
type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number;
|
||||
export interface SessionState {
|
||||
sessionId: number;
|
||||
errors: Array<Promise<string|null>>;
|
||||
}
|
||||
}
|
||||
|
||||
export interface OrtWasmModule extends EmscriptenModule {
|
||||
|
|
@ -71,7 +75,10 @@ export interface OrtWasmModule extends EmscriptenModule {
|
|||
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
|
||||
|
||||
_JsepOutput(context: number, index: number, data: number): number;
|
||||
_JsepGetNodeName(kernel: number): number;
|
||||
|
||||
jsepOnRunStart?(sessionId: number): void;
|
||||
jsepOnRunEnd?(sessionId: number): void;
|
||||
jsepRunPromise?: Promise<number>;
|
||||
// #endregion
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,9 +82,10 @@ export class WebGpuBackend {
|
|||
}
|
||||
|
||||
/**
|
||||
* a KernelID -> kernel info mapping. value is [ name, run function, [optional] preprocess_attribute_once function ]
|
||||
* a KernelID -> kernel info mapping. value is
|
||||
* [ op_type, name, run function, [optional] preprocess_attribute_once function ]
|
||||
*/
|
||||
kernels: Map<number, [string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
|
||||
kernels: Map<number, [string, string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
|
||||
|
||||
commandEncoder: GPUCommandEncoder|null = null;
|
||||
computePassEncoder: GPUComputePassEncoder|null = null;
|
||||
|
|
@ -313,13 +314,13 @@ export class WebGpuBackend {
|
|||
return this.gpuDataManager.release(ptr);
|
||||
}
|
||||
|
||||
createKernel(name: string, kernelId: number, attribute: unknown): void {
|
||||
const op = WEBGPU_OP_RESOLVE_RULES.get(name);
|
||||
createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void {
|
||||
const op = WEBGPU_OP_RESOLVE_RULES.get(opType);
|
||||
if (!op) {
|
||||
throw new Error(`kernel not implemented: ${name}`);
|
||||
throw new Error(`kernel not implemented: ${opType}`);
|
||||
}
|
||||
|
||||
this.kernels.set(kernelId, [name, op[0], [op[1], attribute]]);
|
||||
this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]);
|
||||
}
|
||||
|
||||
releaseKernel(kernelId: number): void {
|
||||
|
|
@ -335,14 +336,14 @@ export class WebGpuBackend {
|
|||
this.kernels.delete(kernelId);
|
||||
}
|
||||
|
||||
computeKernel(kernelId: number, context: ComputeContext): number {
|
||||
computeKernel(kernelId: number, context: ComputeContext, errors: Array<Promise<string|null>>): number {
|
||||
const kernel = this.kernels.get(kernelId);
|
||||
if (!kernel) {
|
||||
throw new Error(`kernel not created: ${kernelId}`);
|
||||
}
|
||||
const [name, kernelEntry, attributes] = kernel;
|
||||
const [opType, nodeName, kernelEntry, attributes] = kernel;
|
||||
if (this.currentKernelId !== null) {
|
||||
throw new Error(`kernel "${name}" is not allowed to be called recursively`);
|
||||
throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`);
|
||||
}
|
||||
this.currentKernelId = kernelId;
|
||||
|
||||
|
|
@ -352,16 +353,27 @@ export class WebGpuBackend {
|
|||
attributes[0] = undefined;
|
||||
}
|
||||
|
||||
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "${name}"...`);
|
||||
LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`);
|
||||
|
||||
const useErrorScope = this.env.debug;
|
||||
|
||||
this.temporaryData = [];
|
||||
try {
|
||||
if (useErrorScope) {
|
||||
this.device.pushErrorScope('validation');
|
||||
}
|
||||
|
||||
kernelEntry(context, attributes[1]);
|
||||
return 0; // ORT_OK
|
||||
} catch (e) {
|
||||
LOG_DEBUG('warning', `[WebGPU] Kernel "${name}" failed. Error: ${e}`);
|
||||
LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`);
|
||||
return 1; // ORT_FAIL
|
||||
} finally {
|
||||
if (useErrorScope) {
|
||||
errors.push(this.device.popErrorScope().then(
|
||||
err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null));
|
||||
}
|
||||
|
||||
for (const data of this.temporaryData) {
|
||||
this.gpuDataManager.release(data.id);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import {Env} from 'onnxruntime-common';
|
||||
|
||||
import {OrtWasmModule} from '../binding/ort-wasm';
|
||||
import {JSEP, OrtWasmModule} from '../binding/ort-wasm';
|
||||
import {DataType, getTensorElementSize} from '../wasm-common';
|
||||
|
||||
import {WebGpuBackend} from './backend-webgpu';
|
||||
|
|
@ -169,16 +169,22 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
|
|||
},
|
||||
|
||||
// jsepCreateKernel
|
||||
(name: string, kernel: number, attribute: unknown) => backend.createKernel(name, kernel, attribute),
|
||||
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
|
||||
name, kernel, attribute,
|
||||
env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
|
||||
`${kernel}`),
|
||||
|
||||
// jsepReleaseKernel
|
||||
(kernel: number) => backend.releaseKernel(kernel),
|
||||
|
||||
// jsepRun
|
||||
(kernel: number, contextDataOffset: number) => {
|
||||
LOG_DEBUG('verbose', () => `[WebGPU] jsepRun: kernel=${kernel}, contextDataOffset=${contextDataOffset}`);
|
||||
(kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => {
|
||||
LOG_DEBUG(
|
||||
'verbose',
|
||||
() => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${
|
||||
contextDataOffset}`);
|
||||
const context = new ComputeContextImpl(module, backend, contextDataOffset);
|
||||
return backend.computeKernel(kernel, context);
|
||||
return backend.computeKernel(kernel, context, sessionState.errors);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -77,7 +77,8 @@ export class ProgramManager {
|
|||
this.backend.flush();
|
||||
|
||||
const kernelId = this.backend.currentKernelId!;
|
||||
const kernelName = this.backend.kernels.get(kernelId)![0];
|
||||
const kernelInfo = this.backend.kernels.get(kernelId)!;
|
||||
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
|
||||
|
||||
syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
|
||||
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
|
||||
|
|
|
|||
|
|
@ -258,12 +258,15 @@ export const run = async(
|
|||
wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
|
||||
}
|
||||
|
||||
wasm.jsepOnRunStart?.(sessionId);
|
||||
|
||||
// support RunOptions
|
||||
let errorCode = wasm._OrtRun(
|
||||
sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount,
|
||||
outputValuesOffset, runOptionsHandle);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
wasm.jsepOnRunEnd?.(sessionId);
|
||||
|
||||
const runPromise = wasm.jsepRunPromise;
|
||||
if (runPromise && typeof runPromise.then !== 'undefined') {
|
||||
errorCode = await runPromise;
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
const void* JsepOutput(void* context, int index, void* data) {
|
||||
uint32_t* data_offset = reinterpret_cast<uint32_t*>(data);
|
||||
const void* JsepOutput(void* context, int index, const void* data) {
|
||||
const uint32_t* data_offset = reinterpret_cast<const uint32_t*>(data);
|
||||
uint32_t dim = *data_offset++;
|
||||
size_t dim_size = static_cast<size_t>(dim);
|
||||
std::vector<int64_t> dims;
|
||||
|
|
@ -24,3 +24,8 @@ const void* JsepOutput(void* context, int index, void* data) {
|
|||
LOGF_DEFAULT(VERBOSE, "JsepOutput -- data=%zu", (size_t)(r));
|
||||
return r;
|
||||
}
|
||||
|
||||
const void* JsepGetNodeName(const void* kernel) {
|
||||
const auto& name = reinterpret_cast<const onnxruntime::OpKernel*>(kernel)->Node().Name();
|
||||
return name.c_str();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
// TODO: Move to api.h
|
||||
|
||||
extern "C" {
|
||||
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, void* data);
|
||||
const void* EMSCRIPTEN_KEEPALIVE JsepOutput(void* context, int index, const void* data);
|
||||
const void* EMSCRIPTEN_KEEPALIVE JsepGetNodeName(const void* context);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -194,7 +194,9 @@ class JsKernel : public OpKernel {
|
|||
return status;
|
||||
}
|
||||
|
||||
int status_code = EM_ASM_INT({ return Module.jsepRun($0, $1); }, this, reinterpret_cast<int32_t>(p_serialized_kernel_context));
|
||||
int status_code = EM_ASM_INT(
|
||||
{ return Module.jsepRunKernel($0, $1, Module.jsepSessionState); },
|
||||
this, reinterpret_cast<int32_t>(p_serialized_kernel_context));
|
||||
|
||||
LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data="
|
||||
<< (size_t)(context->Output<Tensor>(0)->DataRaw()) << ".";
|
||||
|
|
|
|||
|
|
@ -368,12 +368,9 @@ int OrtRun(OrtSession* session,
|
|||
const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count,
|
||||
const char** output_names, size_t output_count, ort_tensor_handle_t* outputs,
|
||||
OrtRunOptions* run_options) {
|
||||
#if defined(USE_JSEP)
|
||||
EM_ASM({ Module["jsepRunPromise"] = new Promise(function(r) { Module.jsepRunPromiseResolve = r; }); });
|
||||
#endif
|
||||
auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs);
|
||||
#if defined(USE_JSEP)
|
||||
EM_ASM({ Module.jsepRunPromiseResolve($0); }, status_code);
|
||||
EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code);
|
||||
#endif
|
||||
return status_code;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,13 +4,53 @@
|
|||
'use strict';
|
||||
|
||||
// init JSEP
|
||||
Module["jsepInit"] = function (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, run) {
|
||||
Module.jsepBackend = backend;
|
||||
Module.jsepAlloc = alloc;
|
||||
Module.jsepFree = free;
|
||||
Module.jsepCopy = copy;
|
||||
Module.jsepCopyAsync = copyAsync;
|
||||
Module.jsepCreateKernel = createKernel;
|
||||
Module.jsepReleaseKernel = releaseKernel;
|
||||
Module.jsepRun = run;
|
||||
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => {
|
||||
Module.jsepBackend = backend;
|
||||
Module.jsepAlloc = alloc;
|
||||
Module.jsepFree = free;
|
||||
Module.jsepCopy = copy;
|
||||
Module.jsepCopyAsync = copyAsync;
|
||||
Module.jsepCreateKernel = createKernel;
|
||||
Module.jsepReleaseKernel = releaseKernel;
|
||||
Module.jsepRunKernel = runKernel;
|
||||
|
||||
Module['jsepOnRunStart'] = sessionId => {
|
||||
Module['jsepRunPromise'] = new Promise(r => {
|
||||
Module.jsepRunPromiseResolve = r;
|
||||
});
|
||||
|
||||
if (Module.jsepSessionState) {
|
||||
throw new Error('Session already started');
|
||||
}
|
||||
|
||||
Module.jsepSessionState = {
|
||||
sessionId,
|
||||
errors: []
|
||||
};
|
||||
};
|
||||
|
||||
Module['jsepOnRunEnd'] = sessionId => {
|
||||
if (Module.jsepSessionState.sessionId !== sessionId) {
|
||||
throw new Error('Session ID mismatch');
|
||||
}
|
||||
|
||||
const errorPromises = Module.jsepSessionState.errors;
|
||||
Module.jsepSessionState = null;
|
||||
|
||||
if (errorPromises.length > 0) {
|
||||
const runPromise = Module['jsepRunPromise'];
|
||||
Module['jsepRunPromise'] = new Promise((resolve, reject) => {
|
||||
Promise.all(errorPromises).then(errors => {
|
||||
errors = errors.filter(e => e);
|
||||
if (errors.length > 0) {
|
||||
reject(new Error(errors.join('\n')));
|
||||
} else {
|
||||
resolve(runPromise);
|
||||
}
|
||||
}, reason => {
|
||||
reject(reason);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue