[js/webgpu] set query type in onRunStart (#19202)

### Description
<!-- Describe your changes. -->
`env.webgpu.profiling` is a global flag. It may change before each
session.run. So the best place is to update it in `onRunStart` event.
After this, we can directly check `this.queryType`'s value. Without this
pr, we need to make sure that `getCommandEncoder()` is called before
checking `this.queryType`. Otherwise, it may happen that
`pendingKernels`'s length is not equal to `pendingDispatchNumber`'s
length. See the two ugly workarounds
[1)](e630dbf528 (diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9dR267-R268))
and
[2)](e630dbf528 (diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1R73-R80))
if we don't introduce `onRunStart`. Or we need to call `setQueryType` in
each kernel run.
This commit is contained in:
Jiajia Qin 2024-01-23 08:08:55 +08:00 committed by GitHub
parent 2e0a388c36
commit d226e40856
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 13 additions and 5 deletions

View file

@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Called when InferenceSession.run started.
*/
jsepOnRunStart: () => void;
// #endregion
}

View file

@ -208,7 +208,7 @@ export class WebGpuBackend {
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
// init queryType, which is necessary for createKernel
// init queryType, which is necessary for InferenceSession.create
this.setQueryType();
}
@ -223,8 +223,6 @@ export class WebGpuBackend {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();
// refresh queryType, as sometimes we only need to enable query for a specific run
this.setQueryType();
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
@ -639,6 +637,7 @@ export class WebGpuBackend {
return createView(data.buffer, type);
};
}
// #endregion
writeTimestamp(index: number): void {
if (this.queryType !== 'inside-passes') {
return;
@ -657,5 +656,7 @@ export class WebGpuBackend {
}
}
}
// #endregion
onRunStart(): void {
this.setQueryType();
}
}

View file

@ -488,8 +488,8 @@ export const run = async(
}
}
wasm.jsepOnRunStart?.();
let errorCode: number;
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);

View file

@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnRunStart'] = () => {
return backend['onRunStart']();
};
};