diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0eb0d40a3e..3ce114c5d3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -588,30 +588,39 @@ const createIndicesHelper = const impl = () => { const impls = []; - if (!useUniform) { - impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`); - impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); - } + let needShapeStrides = false; if (implementationUsed.offsetToIndices) { impls.push(offsetToIndicesImplementation); + needShapeStrides = true; } if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); + needShapeStrides = true; } if (implementationUsed.broadcastedIndicesToOffset) { Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); + needShapeStrides = true; } if (implementationUsed.set) { impls.push(setImplementation); + needShapeStrides = true; } if (implementationUsed.setByIndices) { impls.push(setByIndicesImplementation); + needShapeStrides = true; } if (implementationUsed.get) { impls.push(getImplementation); + needShapeStrides = true; } if (implementationUsed.getByIndices) { impls.push(getByIndicesImplementation); + needShapeStrides = true; + } + if (!useUniform && needShapeStrides) { + impls.unshift( + `const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`, + `const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); } return impls.join('\n'); };