diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index edc0ea12b6..a0abfc0270 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -157,6 +157,16 @@ const validateInputs = ( } }; +const getSafeIntegerDivision = (a: string, b: string, c: string, dType: string): string => ` + // The whole part and the fractional part are calculated separately due to inaccuracy of floating + // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an + // offset-by-one error later in floor(). + let big = (${a}) * (${b}); + let whole = ${dType}(big / (${c})); + let fract = ${dType}(big % (${c})) / ${dType}(${c}); + return whole + fract; +`; + const getOriginalCoordinateFromResizedCoordinate = ( coordinateTransferMode: CoordinateTransformMode, dType: string, @@ -166,7 +176,13 @@ const getOriginalCoordinateFromResizedCoordinate = ( (() => { switch (coordinateTransferMode) { case 'asymmetric': - return `return ${dType}(xResized) / ${dType}(xScale);`; + return ` + if (xScale < 1.0 || floor(xScale) != xScale) { + return ${dType}(xResized) / ${dType}(xScale); + } else { + ${getSafeIntegerDivision('xResized', 'lengthOriginal', 'lengthResized', dType)} + } + `; case 'pytorch_half_pixel': return `if (lengthResized > 1) { return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5; @@ -179,13 +195,7 @@ const getOriginalCoordinateFromResizedCoordinate = ( return `if (lengthResized == 1) { return 0.0; } else { - // The whole part and the fractional part are calculated separately due to inaccuracy of floating - // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an - // offset-by-one error later in floor(). - let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1)); - let fract = - ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1); - return whole + fract; + ${getSafeIntegerDivision('xResized', 'lengthOriginal - 1', 'lengthResized - 1', dType)} }`; case 'tf_crop_and_resize': return `if (lengthResized > 1) { @@ -375,7 +385,7 @@ const calculateInputIndicesFromOutputIndices = ( input_index = u32(original_idx); } } - ${input.indicesSet('input_indices', 'i', ' input_index')} + ${input.indicesSet('input_indices', 'i', 'input_index')} } return input_indices; }`; @@ -758,9 +768,11 @@ const createResizeProgramInfo = ( return { name: 'Resize', shaderCache: { - hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : '' - }|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`, + hint: `${attributes.cacheKey}|${opsetVersion}|${ + scales.length > 0 ? (attributes.mode === 'cubic' ? scales : scales.length) : '' + }|${sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${ + attributes.mode === 'nearest' ? inputShape.length : inputShape + }`, inputDependencies: ['rank'], }, getShaderSource,