From 44584c3ebe7058ce78eb6c1ceebec6697885302b Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 29 Dec 2023 07:43:08 +0800 Subject: [PATCH] [js/webgpu] only declare shape and strides in shader when necessary (#18940) ### Description Previously, shape and strides were added unconditionally even they are not used. This PR fixes this issue and only adds shape and strides when they are required. It's useful when some shapes are not used as uniform if the program depends on type instead of rank. --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0eb0d40a3e..3ce114c5d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -588,30 +588,39 @@ const createIndicesHelper = const impl = () => { const impls = []; - if (!useUniform) { - impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`); - impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); - } + let needShapeStrides = false; if (implementationUsed.offsetToIndices) { impls.push(offsetToIndicesImplementation); + needShapeStrides = true; } if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); + needShapeStrides = true; } if (implementationUsed.broadcastedIndicesToOffset) { Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); + needShapeStrides = true; } if (implementationUsed.set) { impls.push(setImplementation); + needShapeStrides = true; } if (implementationUsed.setByIndices) { impls.push(setByIndicesImplementation); + needShapeStrides = true; } if (implementationUsed.get) { impls.push(getImplementation); + needShapeStrides = true; } if (implementationUsed.getByIndices) { impls.push(getByIndicesImplementation); + needShapeStrides = true; + } + if (!useUniform && needShapeStrides) { + impls.unshift( + `const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`, + `const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); } return impls.join('\n'); };