[JS/WebGPU] Fixed bugs in inputs validation of Resize (#21955)

- 'scales' and 'sizes' may be empty tensor, make sure it's 1D tensor and
non-empty
- Make sure 'scales' and 'sizes' if present its length is non-zero
This commit is contained in:
Wanming Lin 2024-10-05 09:29:53 +08:00 committed by GitHub
parent b5ef85555a
commit 39c8b3759f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -112,7 +112,12 @@ const validateInputs = (
throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
}
if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
if (
scalesInputIndex > 0 &&
inputs.length > scalesInputIndex &&
inputs[scalesInputIndex].dims.length === 1 &&
inputs[scalesInputIndex].dims[0] > 0
) {
inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
if (
scales.length !== 0 &&
@ -127,18 +132,23 @@ const validateInputs = (
updateScales(scales, attributes.axes, rank).forEach((value, index) => (scales[index] = value));
}
}
if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
if (
sizesInputIndex > 0 &&
inputs.length > sizesInputIndex &&
inputs[sizesInputIndex].dims.length === 1 &&
inputs[sizesInputIndex].dims[0] > 0
) {
inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
if (sizes.length !== rank || (opsetVersion >= 18 && sizes.length === attributes.axes.length)) {
if (sizes.length !== 0 && sizes.length !== rank && opsetVersion >= 18 && sizes.length !== attributes.axes.length) {
throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
}
}
if (attributes.axes.length > 0) {
if (scales.length !== attributes.axes.length) {
if (scales.length !== 0 && scales.length !== attributes.axes.length) {
throw new Error('Resize requires "scales" input size to be of axes rank when axes attributes is specified');
}
if (sizes.length !== attributes.axes.length) {
if (sizes.length !== 0 && sizes.length !== attributes.axes.length) {
throw new Error('Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified');
}
}