From 0c8c0014f6f4a52a2e9783dfd8f12b32f48d382b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Sat, 11 Nov 2023 09:37:45 +0800 Subject: [PATCH] [js/webgpu] Use builtin num_workgroups to fix shader key conflict (#18387) This fixes conformance failure of tinyyolov2-8 and potential shader key conflict issues. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 40 +++++++++++------------ js/web/lib/wasm/jsep/webgpu/ops/common.ts | 7 ++-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e4f1808057..e2c2bc8dec 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -54,20 +54,22 @@ const getProgramInputTensorInfoDependencyKey = * program. if the key is the same, the program shader source should be the same, so we can reuse the program. * */ -const getProgramInfoUniqueKey = (programInfo: ProgramInfo, inputTensors: readonly TensorView[]): string => { - // final key format: - // <PROGRAM_NAME>[<PROGRAM_CUSTOM_HINT>]:<INPUTS_INFO_0>|<INPUTS_INFO_1>|... - let key = programInfo.name; - if (programInfo.shaderCache?.hint) { - key += '[' + programInfo.shaderCache.hint + ']'; - } - key += `:${ - getProgramInputTensorInfoDependencyKey( - inputTensors, - programInfo.shaderCache?.inputDependencies ?? - new Array(inputTensors.length).fill('dims'))}`; - return key; -}; +const getProgramInfoUniqueKey = + (programInfo: ProgramInfo, inputTensors: readonly TensorView[], is1DimensionDispatch: boolean): string => { + // final key format: + // <PROGRAM_NAME>[<PROGRAM_CUSTOM_HINT>]:is1DimensionDispatch:<INPUTS_INFO_0>|<INPUTS_INFO_1>|... + let key = programInfo.name; + if (programInfo.shaderCache?.hint) { + key += '[' + programInfo.shaderCache.hint + ']'; + } + key += ':' + is1DimensionDispatch + + `:${ + getProgramInputTensorInfoDependencyKey( + inputTensors, + programInfo.shaderCache?.inputDependencies ?? + new Array(inputTensors.length).fill('dims'))}`; + return key; + }; /** * this class is designed to store status and being used as a singleton for JSEP. 
It will be passed to jsepInit() as @@ -283,10 +285,6 @@ export class WebGpuBackend { inputDatas[i] = gpuData; } - // get program info - const key = getProgramInfoUniqueKey(program, inputTensorViews); - let artifact = this.programManager.getArtifact(key); - const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); // check output indices @@ -407,9 +405,11 @@ export class WebGpuBackend { uniformBufferBinding = {offset: 0, size: currentOffset, buffer: uniformBufferData.buffer}; } - const normalizedDispatchGroup = this.programManager.normalizeDispatchGroupSize(dispatchGroup); - + const is1DimensionDispatch = normalizedDispatchGroup[1] === 1 && normalizedDispatchGroup[2] === 1; + // get program info + const key = getProgramInfoUniqueKey(program, inputTensorViews, is1DimensionDispatch); + let artifact = this.programManager.getArtifact(key); if (!artifact) { artifact = this.programManager.build(program, normalizedDispatchGroup); this.programManager.setArtifact(key, artifact); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 1917c85fe5..38dc14f236 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -717,11 +717,12 @@ class ShaderHelperImpl implements ShaderHelper { const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>, @builtin(local_invocation_id) local_id : vec3<u32>` : `@builtin(local_invocation_index) local_index : u32, - @builtin(workgroup_id) workgroup_id : vec3<u32>`; + @builtin(workgroup_id) workgroup_id : vec3<u32>, + @builtin(num_workgroups) num_workgroups : vec3<u32>`; const globalIdxDefinition = is1DimensionDispatch ? 
'let global_idx = global_id.x;' : - `let global_idx = (workgroup_id.z * ${this.normalizedDispatchGroup[0] * this.normalizedDispatchGroup[1]}u + - workgroup_id.y * ${this.normalizedDispatchGroup[0]}u + workgroup_id.x) * ${ + `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + + workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_index;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})