[webgpu/js] Optimize resize webgpu op & fix precision issues (#23591)

### Description
<!-- Describe your changes. -->

This PR is a follow-up to
https://github.com/microsoft/onnxruntime/pull/23488 and partially
improves upon https://github.com/microsoft/onnxruntime/issues/23403. It
does the following:
- Prevents unnecessary cache shader recompilation for 'nearest' resize
operation.
- Fixes precision (offset-by-one) errors with asymmetric coordinate
transform. When running the Kokoro TTS model, values for the
`/decoder/decoder/generator/f0_upsamp/Resize_output_0` results in
differences at the end bounds due to precision issues when dividing
21600 by 72 (should be 300, but seemingly results in 299.999, which
causes issues when flooring)

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

I did a deep dive over the weekend to try fix Kokoro TTS on WebGPU and
found that the above node had a large difference. Thinking this was a
major issue, I spent some time fixing it. Turns out, it only happens for
a small number of values, leading to high maximum error, but most values
are correct (as seen here).

BEFORE:
```
[/decoder/decoder/generator/f0_upsamp/Resize_output_0] atol: 78.6640682220459 | rtol: 24.13991587587724 | avgDiff: 0.009967932171121087 | medianDiff: 0.000030517578125
```

AFTER:
```
[/decoder/decoder/generator/f0_upsamp/Resize_output_0] atol: 0.0011138916015625 | rtol: 0.0020059924232260704 | avgDiff: 0.00008570214675873825 | medianDiff: 0.000030517578125
```

So, although it has a very small impact on the final output (waveform),
this bug could appear with other models in a more severe way.

BEFORE:
```
[waveform] atol: 0.04784199967980385 | rtol: 1366.0462001093495 | avgDiff: 0.0009544936942737713 | medianDiff: 0.00015346752479672432
```

AFTER:
```
[waveform] atol: 0.04775865003466606 | rtol: 1354.7002460360852 | avgDiff: 0.000954830244055033 | medianDiff: 0.00015274062752723694
```
This commit is contained in:
Joshua Lochner 2025-02-06 20:26:25 +02:00 committed by GitHub
parent 328a13c06d
commit d981b153d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -157,6 +157,16 @@ const validateInputs = (
} }
}; };
const getSafeIntegerDivision = (a: string, b: string, c: string, dType: string): string => `
// The whole part and the fractional part are calculated separately due to inaccuracy of floating
// point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
// offset-by-one error later in floor().
let big = (${a}) * (${b});
let whole = ${dType}(big / (${c}));
let fract = ${dType}(big % (${c})) / ${dType}(${c});
return whole + fract;
`;
const getOriginalCoordinateFromResizedCoordinate = ( const getOriginalCoordinateFromResizedCoordinate = (
coordinateTransferMode: CoordinateTransformMode, coordinateTransferMode: CoordinateTransformMode,
dType: string, dType: string,
@ -166,7 +176,13 @@ const getOriginalCoordinateFromResizedCoordinate = (
(() => { (() => {
switch (coordinateTransferMode) { switch (coordinateTransferMode) {
case 'asymmetric': case 'asymmetric':
return `return ${dType}(xResized) / ${dType}(xScale);`; return `
if (xScale < 1.0 || floor(xScale) != xScale) {
return ${dType}(xResized) / ${dType}(xScale);
} else {
${getSafeIntegerDivision('xResized', 'lengthOriginal', 'lengthResized', dType)}
}
`;
case 'pytorch_half_pixel': case 'pytorch_half_pixel':
return `if (lengthResized > 1) { return `if (lengthResized > 1) {
return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5; return (${dType}(xResized) + 0.5) / ${dType}(xScale) - 0.5;
@ -179,13 +195,7 @@ const getOriginalCoordinateFromResizedCoordinate = (
return `if (lengthResized == 1) { return `if (lengthResized == 1) {
return 0.0; return 0.0;
} else { } else {
// The whole part and the fractional part are calculated separately due to inaccuracy of floating ${getSafeIntegerDivision('xResized', 'lengthOriginal - 1', 'lengthResized - 1', dType)}
// point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
// offset-by-one error later in floor().
let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
let fract =
${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
return whole + fract;
}`; }`;
case 'tf_crop_and_resize': case 'tf_crop_and_resize':
return `if (lengthResized > 1) { return `if (lengthResized > 1) {
@ -375,7 +385,7 @@ const calculateInputIndicesFromOutputIndices = (
input_index = u32(original_idx); input_index = u32(original_idx);
} }
} }
${input.indicesSet('input_indices', 'i', ' input_index')} ${input.indicesSet('input_indices', 'i', 'input_index')}
} }
return input_indices; return input_indices;
}`; }`;
@ -758,9 +768,11 @@ const createResizeProgramInfo = (
return { return {
name: 'Resize', name: 'Resize',
shaderCache: { shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ hint: `${attributes.cacheKey}|${opsetVersion}|${
sizes.length > 0 ? sizes : '' scales.length > 0 ? (attributes.mode === 'cubic' ? scales : scales.length) : ''
}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`, }|${sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${
attributes.mode === 'nearest' ? inputShape.length : inputShape
}`,
inputDependencies: ['rank'], inputDependencies: ['rank'],
}, },
getShaderSource, getShaderSource,