mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/webgpu] allow to specify callback for profiling data (#18732)
### Description
**This PR is a replacement of #17820.**
Allow specifying a callback to receive profiling data.
*Previous*:
```js
ort.env.webgpu.profilingMode = 'default'; // enable profiling
// profiling data will output to console.
```
*Now*:
```js
ort.env.webgpu.profiling = {
mode: 'default', // enable profiling
ondata: (data) => {
// .. process the profiling data
}
};
// "ondata" is called once for each kernel. Profiling data is printed to the console only if "ondata" is not specified.
```
This commit is contained in:
parent
4abec9749e
commit
efbef5f611
5 changed files with 71 additions and 22 deletions
|
|
@ -92,11 +92,48 @@ export declare namespace Env {
|
|||
async?: boolean;
|
||||
}
|
||||
|
||||
export interface WebGpuProfilingDataV1TensorMetadata {
|
||||
dims: readonly number[];
|
||||
dataType: string;
|
||||
}
|
||||
export interface WebGpuProfilingDataV1 {
|
||||
version: 1;
|
||||
inputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[];
|
||||
outputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[];
|
||||
kernelId: number;
|
||||
kernelType: string;
|
||||
kernelName: string;
|
||||
startTime: number;
|
||||
endTime: number;
|
||||
}
|
||||
|
||||
export type WebGpuProfilingData = WebGpuProfilingDataV1;
|
||||
|
||||
export interface WebGpuFlags {
|
||||
/**
|
||||
* Set or get the profiling mode.
|
||||
*
|
||||
* @deprecated Use `env.webgpu.profiling.mode` instead. If `env.webgpu.profiling.mode` is set, this property will be
|
||||
* ignored.
|
||||
*/
|
||||
profilingMode?: 'off'|'default';
|
||||
/**
|
||||
* Set or get the profiling configuration.
|
||||
*/
|
||||
profiling?: {
|
||||
/**
|
||||
* Set or get the profiling mode.
|
||||
*
|
||||
* @defaultValue `'off'`
|
||||
*/
|
||||
mode?: 'off'|'default';
|
||||
|
||||
/**
|
||||
* Set or get a callback function to be called when profiling data is received. If not set, the profiling data will be
|
||||
* printed to console.
|
||||
*/
|
||||
ondata?: (data: WebGpuProfilingData) => void;
|
||||
};
|
||||
/**
|
||||
* Get the device for WebGPU.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -254,11 +254,9 @@ export class WebGpuBackend {
|
|||
}
|
||||
|
||||
isQueryEnabled(): boolean {
|
||||
if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return this.device.features.has('timestamp-query') &&
|
||||
(this.env.webgpu.profiling?.mode === 'default' ||
|
||||
(!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default'));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -175,8 +175,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
|
|||
// jsepCreateKernel
|
||||
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
|
||||
name, kernel, attribute,
|
||||
env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
|
||||
`${kernel}`),
|
||||
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
|
||||
|
||||
// jsepReleaseKernel
|
||||
(kernel: number) => backend.releaseKernel(kernel),
|
||||
|
|
|
|||
|
|
@ -75,12 +75,11 @@ export class ProgramManager {
|
|||
|
||||
const kernelId = this.backend.currentKernelId!;
|
||||
const kernelInfo = this.backend.kernels.get(kernelId)!;
|
||||
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
|
||||
|
||||
void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
|
||||
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
|
||||
const startTimeU64 = mappedData[0];
|
||||
const endTimeU64 = mappedData[1];
|
||||
const [startTimeU64, endTimeU64] = mappedData;
|
||||
const [kernelType, kernelName] = kernelInfo;
|
||||
|
||||
syncData.buffer.unmap();
|
||||
|
||||
|
|
@ -96,17 +95,33 @@ export class ProgramManager {
|
|||
}
|
||||
|
||||
this.backend.gpuDataManager.release(syncData.id);
|
||||
let inputShapes = '';
|
||||
inputTensorViews.forEach((value, i) => {
|
||||
inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
|
||||
});
|
||||
let outputShapes = '';
|
||||
outputTensorViews.forEach((value, i) => {
|
||||
outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
|
||||
});
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${
|
||||
outputShapes}execution time: ${endTime - startTime} ns`);
|
||||
if (this.backend.env.webgpu.profiling?.ondata) {
|
||||
this.backend.env.webgpu.profiling.ondata({
|
||||
version: 1,
|
||||
inputsMetadata: inputTensorViews.map(
|
||||
value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
|
||||
outputsMetadata: outputTensorViews.map(
|
||||
value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
|
||||
kernelId,
|
||||
kernelType,
|
||||
kernelName,
|
||||
startTime,
|
||||
endTime,
|
||||
});
|
||||
} else {
|
||||
// if no callback is provided, print the profiling message to console
|
||||
let inputShapes = '';
|
||||
inputTensorViews.forEach((value, i) => {
|
||||
inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
|
||||
});
|
||||
let outputShapes = '';
|
||||
outputTensorViews.forEach((value, i) => {
|
||||
outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
|
||||
});
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${
|
||||
outputShapes}execution time: ${endTime - startTime} ns`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ if (options.globalEnvFlags) {
|
|||
ort.env.wasm.initTimeout = flags.wasm.initTimeout;
|
||||
}
|
||||
if (flags.webgpu?.profilingMode !== undefined) {
|
||||
ort.env.webgpu.profilingMode = flags.webgpu.profilingMode;
|
||||
ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode};
|
||||
}
|
||||
if (flags.webgpu?.validateInputContent !== undefined) {
|
||||
ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent;
|
||||
|
|
|
|||
Loading…
Reference in a new issue