mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/webgpu] Use builtin num_workgroups to fix shader key conflict (#18387)
This fixes conformance failure of tinyyolov2-8 and potential shader key conflict issues.
This commit is contained in:
parent
6b0c97b43f
commit
0c8c0014f6
2 changed files with 24 additions and 23 deletions
|
|
@ -54,20 +54,22 @@ const getProgramInputTensorInfoDependencyKey =
|
|||
* program. if the key is the same, the program shader source should be the same, so we can reuse the program.
|
||||
*
|
||||
*/
|
||||
const getProgramInfoUniqueKey = (programInfo: ProgramInfo, inputTensors: readonly TensorView[]): string => {
|
||||
// final key format:
|
||||
// <PROGRAM_NAME>[<PROGRAM_CUSTOM_CACHE_HINT>]:<INPUTS_INFO_0>|<INPUTS_INFO_1>|...
|
||||
let key = programInfo.name;
|
||||
if (programInfo.shaderCache?.hint) {
|
||||
key += '[' + programInfo.shaderCache.hint + ']';
|
||||
}
|
||||
key += `:${
|
||||
getProgramInputTensorInfoDependencyKey(
|
||||
inputTensors,
|
||||
programInfo.shaderCache?.inputDependencies ??
|
||||
new Array<ProgramInputTensorInfoDependency>(inputTensors.length).fill('dims'))}`;
|
||||
return key;
|
||||
};
|
||||
const getProgramInfoUniqueKey =
|
||||
(programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => {
|
||||
// final key format:
|
||||
// <PROGRAM_NAME>[<PROGRAM_CUSTOM_CACHE_HINT>]:is1DimensionDispatch:<INPUTS_INFO_0>|<INPUTS_INFO_1>|...
|
||||
let key = programInfo.name;
|
||||
if (programInfo.shaderCache?.hint) {
|
||||
key += '[' + programInfo.shaderCache.hint + ']';
|
||||
}
|
||||
key += ':' + is1DimensionDispatch +
|
||||
`:${
|
||||
getProgramInputTensorInfoDependencyKey(
|
||||
inputTensors,
|
||||
programInfo.shaderCache?.inputDependencies ??
|
||||
new Array<ProgramInputTensorInfoDependency>(inputTensors.length).fill('dims'))}`;
|
||||
return key;
|
||||
};
|
||||
|
||||
/**
|
||||
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
|
||||
|
|
@ -283,10 +285,6 @@ export class WebGpuBackend {
|
|||
inputDatas[i] = gpuData;
|
||||
}
|
||||
|
||||
// get program info
|
||||
const key = getProgramInfoUniqueKey(program, inputTensorViews);
|
||||
let artifact = this.programManager.getArtifact(key);
|
||||
|
||||
const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
|
||||
|
||||
// check output indices
|
||||
|
|
@ -407,9 +405,11 @@ export class WebGpuBackend {
|
|||
uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer};
|
||||
}
|
||||
|
||||
|
||||
const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup);
|
||||
|
||||
const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1;
|
||||
// get program info
|
||||
const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch);
|
||||
let artifact = this.programManager.getArtifact(key);
|
||||
if (!artifact) {
|
||||
artifact = this.programManager.build(program, normalizedDispatchGroup);
|
||||
this.programManager.setArtifact(key, artifact);
|
||||
|
|
|
|||
|
|
@ -717,11 +717,12 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>` :
|
||||
`@builtin(local_invocation_index) local_index : u32,
|
||||
@builtin(workgroup_id) workgroup_id : vec3<u32>`;
|
||||
@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
|
||||
const globalIdxDefinition = is1DimensionDispatch ?
|
||||
'let global_idx = global_id.x;' :
|
||||
`let global_idx = (workgroup_id.z * ${this.normalizedDispatchGroup[0] * this.normalizedDispatchGroup[1]}u +
|
||||
workgroup_id.y * ${this.normalizedDispatchGroup[0]}u + workgroup_id.x) * ${
|
||||
`let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
|
||||
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
|
||||
workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`;
|
||||
|
||||
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
|
||||
|
|
|
|||
Loading…
Reference in a new issue