[js/webgpu] allow to specify callback for profiling data (#18732)

### Description

**This PR is a replacement of #17820.**

Allow specifying a callback that receives the profiling data.

*Previous*:
```js
ort.env.webgpu.profilingMode = 'default';  // enable profiling

// profiling data will be output to the console.
```

*Now*:
```js
ort.env.webgpu.profiling = {
  mode: 'default',  // enable profiling
  ondata: (data) => {
    // .. process the profiling data
  }
};

// For each kernel, "ondata" will be called once. Profiling data is only output to the console if "ondata" is not specified.
```
This commit is contained in:
Yulong Wang 2023-12-07 14:10:28 -08:00 committed by GitHub
parent 4abec9749e
commit efbef5f611
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 71 additions and 22 deletions

View file

@ -92,11 +92,48 @@ export declare namespace Env {
async?: boolean;
}
/**
 * Describes one tensor (an input or an output) that took part in a profiled kernel run.
 */
export interface WebGpuProfilingDataV1TensorMetadata {
  /** Dimensions of the tensor. */
  dims: ReadonlyArray<number>;
  /** Name of the tensor's element data type, given as a string. */
  dataType: string;
}
/**
 * Version 1 of the profiling record handed to the `ondata` profiling callback, one record per profiled kernel run.
 */
export interface WebGpuProfilingDataV1 {
  /** Format discriminator; always the literal `1` for this record shape. */
  version: 1;

  /** Numeric id of the kernel that was executed. */
  kernelId: number;
  /** Type of the kernel. */
  kernelType: string;
  /** Name of the kernel. */
  kernelName: string;

  /** Metadata (dims and data type) of every input tensor, in input order. */
  inputsMetadata: ReadonlyArray<WebGpuProfilingDataV1TensorMetadata>;
  /** Metadata (dims and data type) of every output tensor, in output order. */
  outputsMetadata: ReadonlyArray<WebGpuProfilingDataV1TensorMetadata>;

  /** Timestamp taken when the kernel started. NOTE(review): the console fallback prints `endTime - startTime` as ns. */
  startTime: number;
  /** Timestamp taken when the kernel finished; same unit as `startTime`. */
  endTime: number;
}
/**
 * The profiling data type accepted by the `ondata` profiling callback.
 *
 * Currently an alias of `WebGpuProfilingDataV1`, whose `version` field identifies the record shape.
 */
export type WebGpuProfilingData = WebGpuProfilingDataV1;
export interface WebGpuFlags {
/**
* Set or get the profiling mode.
*
* @deprecated Use `env.webgpu.profiling.mode` instead. If `env.webgpu.profiling.mode` is set, this property will be
* ignored.
*/
profilingMode?: 'off'|'default';
/**
* Set or get the profiling configuration.
*/
profiling?: {
/**
* Set or get the profiling mode.
*
* @defaultValue `'off'`
*/
mode?: 'off'|'default';
/**
* Set or get a callback function that is invoked when profiling data is received. If not set, the profiling
* data will be printed to the console.
*/
ondata?: (data: WebGpuProfilingData) => void;
};
/**
* Get the device for WebGPU.
*

View file

@ -254,11 +254,9 @@ export class WebGpuBackend {
}
/**
 * Tell whether GPU timestamp queries should be used to profile kernels.
 *
 * The device must expose the 'timestamp-query' feature. When it does, `env.webgpu.profiling.mode` decides the
 * outcome if it is set; the deprecated `env.webgpu.profilingMode` is consulted only when `profiling.mode` is unset.
 *
 * @returns `true` if the feature is available and the effective profiling mode is `'default'`; otherwise `false`.
 */
isQueryEnabled(): boolean {
  return this.device.features.has('timestamp-query') &&
      (this.env.webgpu.profiling?.mode === 'default' ||
       // legacy flag: honoured only as a fallback when the new `profiling.mode` has no value
       (!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default'));
}
/**

View file

@ -175,8 +175,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise<void> => {
// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
`${kernel}`),
env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),

View file

@ -75,12 +75,11 @@ export class ProgramManager {
const kernelId = this.backend.currentKernelId!;
const kernelInfo = this.backend.kernels.get(kernelId)!;
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
const startTimeU64 = mappedData[0];
const endTimeU64 = mappedData[1];
const [startTimeU64, endTimeU64] = mappedData;
const [kernelType, kernelName] = kernelInfo;
syncData.buffer.unmap();
@ -96,17 +95,33 @@ export class ProgramManager {
}
this.backend.gpuDataManager.release(syncData.id);
let inputShapes = '';
inputTensorViews.forEach((value, i) => {
inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
let outputShapes = '';
outputTensorViews.forEach((value, i) => {
outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
// eslint-disable-next-line no-console
console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${
outputShapes}execution time: ${endTime - startTime} ns`);
if (this.backend.env.webgpu.profiling?.ondata) {
  // A callback is registered: deliver one structured (version 1) record for this kernel run instead of logging.
  this.backend.env.webgpu.profiling.ondata({
    version: 1,
    inputsMetadata: inputTensorViews.map(
        value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
    outputsMetadata: outputTensorViews.map(
        value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
    kernelId,
    kernelType,
    kernelName,
    startTime,
    endTime,
  });
} else {
  // if no callback is provided, print the profiling message to console
  let inputShapes = '';
  inputTensorViews.forEach((value, i) => {
    inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
  });
  // The output section must be built from the output views so the "output[i]" labels describe real outputs.
  let outputShapes = '';
  outputTensorViews.forEach((value, i) => {
    outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
  });
  // eslint-disable-next-line no-console
  console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${
      outputShapes}execution time: ${endTime - startTime} ns`);
}
});
}

View file

@ -56,7 +56,7 @@ if (options.globalEnvFlags) {
ort.env.wasm.initTimeout = flags.wasm.initTimeout;
}
if (flags.webgpu?.profilingMode !== undefined) {
ort.env.webgpu.profilingMode = flags.webgpu.profilingMode;
ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode};
}
if (flags.webgpu?.validateInputContent !== undefined) {
ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent;