[js/webgpu] Simplify the Resize shader when noScale is true (#18321)

### Description
For Resize, when `noScale` is true, the shader can become very simple,
which is not related with `attributes.mode` anymore. So we should remove
those parts of shader code for simplification.

This PR can also fix #18311 since the `noScale` are all true in that
model.

However, #18311 also exposes that the Resize implementation for `linear`
mode has bug. It seems that the currently implementation always treat
the input as either 2d or 4d tensor, however, the actual input is 3d
tensor, that's why the shader compilation is failed. We may need to fix
it in a separate PR.
This commit is contained in:
Jiajia Qin 2023-11-08 04:54:20 +08:00 committed by GitHub
parent 6127dd1d2d
commit 606356d0b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -454,6 +454,7 @@ const createResizeProgramInfo =
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const getShaderSource = (shaderHelper: ShaderHelper) => `
${noScale ? '' : `
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
${(() => {
switch (attributes.mode) {
@ -483,23 +484,22 @@ const createResizeProgramInfo =
throw Error('Invalid resize mode');
}
})()};
`}
${shaderHelper.declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
if (${noScale}) {
output[global_idx] = input[global_idx];
} else {
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
${(() => {
${noScale ? 'output[global_idx] = input[global_idx];' : `
let outputIndices = ${output.offsetToIndices('global_idx')};
var inputIndices: ${input.type.indices};
${(() => {
switch (attributes.mode) {
case 'nearest':
return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices);
if (checkInputIndices(inputIndices)) {
output[global_idx] = input[${input.indicesToOffset('inputIndices')}];
} else {
output[global_idx] = ${attributes.extrapolationValue};
}`;
if (checkInputIndices(inputIndices)) {
output[global_idx] = input[${input.indicesToOffset('inputIndices')}];
} else {
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
return 'output[global_idx] = bilinearInterpolation(outputIndices);';
case 'cubic':
@ -508,14 +508,14 @@ const createResizeProgramInfo =
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
})()};
}
`}
}`;
return {
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
sizes.length > 0 ? sizes : ''}`
sizes.length > 0 ? sizes : ''}|${noScale}`
},
getShaderSource,
getRunData: () => ({