diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index fed1dbcf51..07cfefb8f1 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -454,6 +454,7 @@ const createResizeProgramInfo = const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${noScale ? '' : ` ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)}; ${(() => { switch (attributes.mode) { @@ -483,23 +484,22 @@ const createResizeProgramInfo = throw Error('Invalid resize mode'); } })()}; + `} ${shaderHelper.declareVariables(input, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - if (${noScale}) { - output[global_idx] = input[global_idx]; - } else { - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - ${(() => { + ${noScale ? 'output[global_idx] = input[global_idx];' : ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + var inputIndices: ${input.type.indices}; + ${(() => { switch (attributes.mode) { case 'nearest': return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices); - if (checkInputIndices(inputIndices)) { - output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; - } else { - output[global_idx] = ${attributes.extrapolationValue}; - }`; + if (checkInputIndices(inputIndices)) { + output[global_idx] = input[${input.indicesToOffset('inputIndices')}]; + } else { + output[global_idx] = ${attributes.extrapolationValue}; + }`; case 'linear': return 'output[global_idx] = bilinearInterpolation(outputIndices);'; case 'cubic': @@ -508,14 +508,14 @@ const createResizeProgramInfo = throw Error(`Unsupported resize mode: ${attributes.mode}`); } })()}; - } + `} }`; return { name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}` + sizes.length > 0 ? sizes : ''}|${noScale}` }, getShaderSource, getRunData: () => ({