From 39c8b3759f07059cde50df7ed48e947c2579aa98 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 5 Oct 2024 09:29:53 +0800 Subject: [PATCH] [JS/WebGPU] Fixed bugs in inputs validation of Resize (#21955) - 'scales' and 'sizes' may be empty tensor, make sure it's 1D tensor and non-empty - Make sure 'scales' and 'sizes' if present its length is non-zero --- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 3cd7540ca0..edc0ea12b6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -112,7 +112,12 @@ const validateInputs = ( throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize'); } - if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) { + if ( + scalesInputIndex > 0 && + inputs.length > scalesInputIndex && + inputs[scalesInputIndex].dims.length === 1 && + inputs[scalesInputIndex].dims[0] > 0 + ) { inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value)); if ( scales.length !== 0 && @@ -127,18 +132,23 @@ const validateInputs = ( updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value)); } } - if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) { + if ( + sizesInputIndex > 0 && + inputs.length > sizesInputIndex && + inputs[sizesInputIndex].dims.length === 1 && + inputs[sizesInputIndex].dims[0] > 0 + ) { inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value))); - if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) { + if (sizes.length !== 0 && sizes.length !== rank && opsetVersion >= 18 && sizes.length !== attributes.axes.length) { throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up'); } } if (attributes.axes.length > 0) { - if (scales.length !== attributes.axes.length) { + if (scales.length !== 0 && scales.length !== attributes.axes.length) { throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified'); } - if (sizes.length !== attributes.axes.length) { + if (sizes.length !== 0 && sizes.length !== attributes.axes.length) { throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified'); } }