mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/webgpu] set query type in onRunStart (#19202)
### Description <!-- Describe your changes. --> `env.webgpu.profiling` is a global flag, and it may change before each `session.run` call, so the best place to refresh the query type is the `onRunStart` event. With this change, we can check the value of `this.queryType` directly. Without this PR, we would need to make sure that `getCommandEncoder()` is called before checking `this.queryType`; otherwise the length of `pendingKernels` may not equal the length of `pendingDispatchNumber`. If we do not introduce `onRunStart`, we need either the two ugly workarounds shown in commit e630dbf528 — 1) diff-006fc84d3997f96a29b8033bd2075d6a0a9509211bd5812a6b934fc74fedfd9d, lines R267-R268, and 2) diff-618fe297fbe7a1da586380163b8fd2627311ccc217640a3c5cdc9c17a33472c1, lines R73-R80 — or a call to `setQueryType` in each kernel run.
This commit is contained in:
parent
2e0a388c36
commit
d226e40856
4 changed files with 13 additions and 5 deletions
4
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
4
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
|
|
@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
|
|||
jsepCreateDownloader:
|
||||
(gpuBuffer: GPUBuffer, size: number,
|
||||
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
|
||||
/**
|
||||
* [exported from js_internal_api.js] Called when InferenceSession.run started.
|
||||
*/
|
||||
jsepOnRunStart: () => void;
|
||||
// #endregion
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -208,7 +208,7 @@ export class WebGpuBackend {
|
|||
|
||||
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
|
||||
|
||||
// init queryType, which is necessary for createKernel
|
||||
// init queryType, which is necessary for InferenceSession.create
|
||||
this.setQueryType();
|
||||
}
|
||||
|
||||
|
|
@ -223,8 +223,6 @@ export class WebGpuBackend {
|
|||
if (!this.commandEncoder) {
|
||||
this.commandEncoder = this.device.createCommandEncoder();
|
||||
|
||||
// refresh queryType, as sometimes we only need to enable query for a specific run
|
||||
this.setQueryType();
|
||||
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
|
||||
this.querySet = this.device.createQuerySet({
|
||||
type: 'timestamp',
|
||||
|
|
@ -639,6 +637,7 @@ export class WebGpuBackend {
|
|||
return createView(data.buffer, type);
|
||||
};
|
||||
}
|
||||
// #endregion
|
||||
writeTimestamp(index: number): void {
|
||||
if (this.queryType !== 'inside-passes') {
|
||||
return;
|
||||
|
|
@ -657,5 +656,7 @@ export class WebGpuBackend {
|
|||
}
|
||||
}
|
||||
}
|
||||
// #endregion
|
||||
onRunStart(): void {
|
||||
this.setQueryType();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -488,8 +488,8 @@ export const run = async(
|
|||
}
|
||||
}
|
||||
|
||||
wasm.jsepOnRunStart?.();
|
||||
let errorCode: number;
|
||||
|
||||
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
|
||||
errorCode = await wasm._OrtRunWithBinding(
|
||||
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
|
||||
|
|
|
|||
|
|
@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
|
|||
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
|
||||
return backend['createDownloader'](gpuBuffer, size, type);
|
||||
};
|
||||
Module['jsepOnRunStart'] = () => {
|
||||
return backend['onRunStart']();
|
||||
};
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue