diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index e1369c2c2b..d20ef63222 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -219,7 +219,7 @@ const initOutputShape = return outputShape; }; -const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => { +const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => { const scaleInPolicy = (() => { switch (attributes.keepAspectRatioPolicy) { case 'not_larger': @@ -312,21 +312,27 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): return true; }`; +const setChannelAndBatchIndices = + (input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string => + input.rank > spacialDims ? ` + ${input.indicesSet('input_indices', channelIdx, 'channel')}; + ${input.indicesSet('input_indices', batchIdx, 'batch')}; +` : + ''; + const bilinearInterpolation = - (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[], - useExtrapolation: boolean, extrapolationValue: number): string => { + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean, + extrapolationValue: number): string => { + const isNchw = true; const [batchIdx, heightIdx, widthIdx, channelIdx] = - inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]); + inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? 
[0, 2, 3, 1] : [0, 1, 2, 3]); const dType = input.type.value; return ` fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} { var input_indices: ${input.type.indices}; ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)}; ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)}; - if (${inputShape.length} > 2) { - ${input.indicesSet('input_indices', channelIdx, 'channel')}; - ${input.indicesSet('input_indices', batchIdx, 'batch')}; - }; + ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)} return ${input.getByIndices('input_indices')}; } @@ -334,30 +340,36 @@ const bilinearInterpolation = var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); var row:${dType} = originalIndices[${heightIdx}]; var col:${dType} = originalIndices[${widthIdx}]; - if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${ - inputShape[widthIdx]} - 1)) { + ${ + useExtrapolation ? + `if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1)) { return ${extrapolationValue}; - } + }` : + ''}; row = max(0, min(row, ${inputShape[heightIdx]} - 1)); col = max(0, min(col, ${inputShape[widthIdx]} - 1)); var row1: u32 = u32(row); var col1: u32 = u32(col); var row2: u32 = u32(row + 1); var col2: u32 = u32(col + 1); - var channel: u32 = 0; - var batch: u32 = 0; - if (${inputShape.length > 2}) { - channel = u32(originalIndices[${channelIdx}]); - batch = u32(originalIndices[${batchIdx}]); - } + var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'}; + var batch: u32 = ${inputShape.length > 2 ? 
`u32(originalIndices[${batchIdx}])` : '0'}; var x11: ${dType} = getInputValue(batch, channel, row1, col1); var x12: ${dType} = getInputValue(batch, channel, row1, col2); var x21: ${dType} = getInputValue(batch, channel, row2, col1); var x22: ${dType} = getInputValue(batch, channel, row2, col2); - var dx1: ${dType} = row - ${dType}(row1); - var dx2: ${dType} = ${dType}(row2) - row; - var dy1 = col - ${dType}(col1); - var dy2 = ${dType}(col2) - col; + var dx1: ${dType} = abs(row - ${dType}(row1)); + var dx2: ${dType} = abs(${dType}(row2) - row); + var dy1: ${dType} = abs(col - ${dType}(col1)); + var dy2: ${dType} = abs(${dType}(col2) - col); + if (row1 == row2) { + dx1 = 0.5; + dx2 = 0.5; + } + if (col1 == col2) { + dy1 = 0.5; + dy2 = 0.5; + } return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1); }`; }; @@ -366,7 +378,9 @@ const bicubicInterpolation = (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => { - const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2]; + const is2D = inputShape.length === 2; + const isNchw = true; + const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2]; const dType = input.type.value; const createCubicInterpolationFunction = (idx: number): string => { const direction = idx === heightIdx ? 
'row' : 'col'; @@ -386,16 +400,18 @@ const bicubicInterpolation = for (var i: i32 = -1; i < 3; i++) { var ${direction}: ${dType} = originalIdx + ${dType}(i); if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) { - if (${excludeOutside}) { - coefs[i + 1] = 0.0; - continue; - } else if (${useExtrapolation}) { - return ${extrapolationValue}; - } else { - ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1)); - } + ${(() => { + if (excludeOutside) { + return `coefs[i + 1] = 0.0; + continue;`; + } else if (useExtrapolation) { + return `return ${extrapolationValue};`; + } else { + return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`; } - var input_indices_copy: ${input.type.indices} = input_indices; + })()}; + } + var input_indices_copy: ${input.type.indices} = input_indices; ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)}; data[i + 1] = ${ idx === heightIdx ? input.getByIndices('input_indices_copy') : @@ -435,6 +451,78 @@ const bicubicInterpolation = `; }; +const trilinearInterpolation = + (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean, + extrapolationValue: number): string => { + const isNchw = true; + const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] = + inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? 
[0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]); + const dType = input.type.value; + return ` + fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} { + var input_indices: ${input.type.indices}; + ${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)}; + ${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)}; + ${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)}; + ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)} + return ${input.getByIndices('input_indices')}; + } + + fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} { + var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices); + var depth:${dType} = originalIndices[${depthIdx}]; + var height:${dType} = originalIndices[${heightIdx}]; + var width:${dType} = originalIndices[${widthIdx}]; + ${ + useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${ + inputShape[heightIdx]} - 1) || width < 0 || width > (${inputShape[widthIdx]} - 1)) { + return ${extrapolationValue}; + }` : + ''}; + + depth = max(0, min(depth, ${inputShape[depthIdx]} - 1)); + height = max(0, min(height, ${inputShape[heightIdx]} - 1)); + width = max(0, min(width, ${inputShape[widthIdx]} - 1)); + var depth1: u32 = u32(depth); + var height1: u32 = u32(height); + var width1: u32 = u32(width); + var depth2: u32 = u32(depth + 1); + var height2: u32 = u32(height + 1); + var width2: u32 = u32(width + 1); + var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'}; + var batch: u32 = ${inputShape.length > 3 ? 
`u32(originalIndices[${batchIdx}])` : '0'}; + + var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1); + var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2); + var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1); + var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2); + var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1); + var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2); + var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1); + var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2); + var dx1: ${dType} = abs(depth - ${dType}(depth1)); + var dx2: ${dType} = abs(${dType}(depth2) - depth); + var dy1: ${dType} = abs(height - ${dType}(height1)); + var dy2: ${dType} = abs(${dType}(height2) - height); + var dz1: ${dType} = abs(width - ${dType}(width1)); + var dz2: ${dType} = abs(${dType}(width2) - width); + if (depth1 == depth2) { + dx1 = 0.5; + dx2 = 0.5; + } + if (height1 == height2) { + dy1 = 0.5; + dy2 = 0.5; + } + if (width1 == width2) { + dz1 = 0.5; + dz2 = 0.5; + } + return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 + + x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1); + }`; + }; + const createResizeProgramInfo = (inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[], sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => { @@ -454,6 +542,7 @@ const createResizeProgramInfo = const outputSize = ShapeUtil.size(outputShape); const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]); const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize'; + const extrapolationValue = attributes.extrapolationValue; const dataType = input.type.value; 
const getShaderSource = (shaderHelper: ShaderHelper) => ` ${noScale ? '' : ` @@ -471,16 +560,28 @@ const createResizeProgramInfo = case 'linear': return ` ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)}; - ${ - bilinearInterpolation( - input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)}; - `; + ${(() => { + if (inputShape.length === 2 || inputShape.length === 4) { + return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; + } else if (inputShape.length === 3 || inputShape.length === 5) { + return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`; + } else { + throw Error('Linear mode only supports input dims 2, 3, 4 and 5.'); + } + })()}; + `; case 'cubic': return ` - ${ - bicubicInterpolation( - input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation, - attributes.extrapolationValue, attributes.excludeOutside)}; + ${(() => { + if (inputShape.length === 2 || inputShape.length === 4) { + return `${ + bicubicInterpolation( + input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation, + attributes.extrapolationValue, attributes.excludeOutside)}`; + } else { + throw Error('Cubic mode only supports input dims 2 and 4.'); + } + })()}; `; default: throw Error('Invalid resize mode'); @@ -507,21 +608,23 @@ const createResizeProgramInfo = output[global_idx] = ${attributes.extrapolationValue}; }`; case 'linear': - return 'output[global_idx] = bilinearInterpolation(output_indices);'; + return `output[global_idx] = ${ + (inputShape.length === 2 || inputShape.length === 4) ? 
'bilinearInterpolation' : + 'trilinearInterpolation'}(output_indices);`; case 'cubic': return 'output[global_idx] = bicubicInterpolation(output_indices);'; default: throw Error(`Unsupported resize mode: ${attributes.mode}`); } })()}; - `} +`} }`; return { name: 'Resize', shaderCache: { hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${ - sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`, + sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`, inputDependencies: ['rank'] }, getShaderSource, @@ -551,6 +654,9 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v const sizes: number[] = []; const roi: number[] = []; const opsetVersion = getOpsetVersionFromCustomDataBuffer(context); + if (attributes.antialias !== 0) { + throw Error('Only default value (0) for Antialias attribute is supported'); + } validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi); context.compute( createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]}); diff --git a/js/web/test/data/ops/upsample.jsonc b/js/web/test/data/ops/upsample.jsonc index 6c11a8fa3f..c0ad8e547f 100644 --- a/js/web/test/data/ops/upsample.jsonc +++ b/js/web/test/data/ops/upsample.jsonc @@ -2,6 +2,7 @@ { "name": "Upsample - Nearest", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 3.0], "type": "floats" }], "cases": [ { @@ -32,6 +33,7 @@ { "name": "Upsample - Nearest2X", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 2.0], "type": "floats" }], "cases": [ { @@ -60,6 +62,7 @@ { "name": "Upsample - Nearest222X", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [{ "name": "scales", "data": [2.0, 1.0, 2.0, 2.0], "type": "floats" }], "cases": [ { @@ -92,6 +95,7 
@@ { "name": "Upsample - Nearest15X", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 1.5], "type": "floats" }], "cases": [ { @@ -120,6 +124,7 @@ { "name": "Upsample - Nearest_NoScale", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [ { "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" }, { "name": "mode", "data": "nearest", "type": "string" } @@ -147,6 +152,7 @@ { "name": "Upsample - 4D Bilinear", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [ { "name": "scales", "data": [1.0, 1.0, 2.0, 4.0], "type": "floats" }, { "name": "mode", "data": "linear", "type": "string" } @@ -180,6 +186,7 @@ { "name": "Upsample - 2D Bilinear", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [ { "name": "scales", "data": [2.0, 4.0], "type": "floats" }, { "name": "mode", "data": "linear", "type": "string" } @@ -210,6 +217,7 @@ { "name": "Upsample - 4D Bilinear ScalesNoOp", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [ { "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" }, { "name": "mode", "data": "linear", "type": "string" } @@ -237,6 +245,7 @@ { "name": "Upsample - 1D Nearest", "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, "attributes": [ { "name": "scales", "data": [2.0], "type": "floats" }, { "name": "mode", "data": "nearest", "type": "string" } @@ -260,5 +269,98 @@ ] } ] + }, + { + "name": "Upsample - 5D Trilinear", + "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, + "attributes": [ + { "name": "scales", "data": [1.0, 1.0, 1.0, 2.0, 4.0], "type": "floats" }, + { "name": "mode", "data": "linear", "type": "string" } + ], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0], + "dims": [1, 2, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + 
"data": [ + 1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0, + 5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0, + + 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0, 5.0, 5.5, 6.0, 6.5, 7.0, 7.0, 7.0, 7.0, 7.0, 7.5, 8.0, 8.5, 9.0, + 9.0, 9.0, 9.0, 7.0, 7.5, 8.0, 8.5, 9.0, 9.0, 9.0, 9.0 + ], + "dims": [1, 2, 1, 4, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Upsample - 3D Trilinear", + "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, + "attributes": [ + { "name": "scales", "data": [1.0, 2.0, 4.0], "type": "floats" }, + { "name": "mode", "data": "linear", "type": "string" } + ], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [1.0, 3.0, 3.0, 5.0], + "dims": [1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0, + 5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0 + ], + "dims": [1, 4, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Upsample - 3D Trilinear ScalesNoOp", + "operator": "Upsample", + "opset": { "domain": "", "version": 7 }, + "attributes": [ + { "name": "scales", "data": [1.0, 1.0, 1.0], "type": "floats" }, + { "name": "mode", "data": "linear", "type": "string" } + ], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0], + "dims": [2, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index a313adef71..594ce9feed 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1392,6 +1392,7 @@ "tile.jsonc", "transpose.jsonc", "transpose_int32_uint32.jsonc", + "upsample.jsonc", "where.jsonc" // Turn on this when 
https://github.com/microsoft/onnxruntime/issues/17405 is fixed. //"where_broadcast.jsonc", diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 8f438a319f..5c0fbf93a4 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -3,8 +3,8 @@ #pragma once -#include #include +#include #include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" @@ -17,7 +17,6 @@ class ConvBase : public JsKernel { ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { - std::vector activation_params; TensorShapeVector kernel_shape; const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size(); std::vector local_pads(pads_vec_size, 0); @@ -28,13 +27,8 @@ class ConvBase : public JsKernel { if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - if (is_fused_conv) { - ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); - ORT_THROW_IF_ERROR(info.GetAttrs("activation_params", activation_params)); - } else { - conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); - activation_params = info.GetAttrsOrDefault("activation_params", activation_params); - } + conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); + std::vector activation_params = info.GetAttrsOrDefault("activation_params"); const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;