[JS/WebGPU] Add trilinear interpolation to Resize; activation_params attribute is optional for FusedConv also. (#18842)

### Description
Add trilinear interpolation to Resize and make the activation_params attribute optional for FusedConv.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-12-27 16:21:29 -08:00 committed by GitHub
parent 31d4a21c4b
commit 3bbe4fe2ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 254 additions and 51 deletions

View file

@ -219,7 +219,7 @@ const initOutputShape =
return outputShape;
};
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => {
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => {
const scaleInPolicy = (() => {
switch (attributes.keepAspectRatioPolicy) {
case 'not_larger':
@ -312,21 +312,27 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]):
return true;
}`;
/**
 * Emits the shader statements that store the `channel` and `batch` values into `input_indices`.
 *
 * Nothing is emitted when the input rank does not exceed the number of spatial dimensions, i.e. when the
 * tensor has no extra axes beyond the spatial ones.
 *
 * @param input helper describing the input tensor; its `rank` and `indicesSet` members are used.
 * @param channelIdx axis position that receives the `channel` value.
 * @param batchIdx axis position that receives the `batch` value.
 * @param spacialDims number of spatial dimensions handled by the caller (2 for bilinear, 3 for trilinear).
 * @returns the generated statements, or an empty string.
 */
const setChannelAndBatchIndices =
    (input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string => {
      if (!(input.rank > spacialDims)) {
        return '';
      }
      const assignChannel = input.indicesSet('input_indices', channelIdx, 'channel');
      const assignBatch = input.indicesSet('input_indices', batchIdx, 'batch');
      // Leading and trailing empty entries reproduce the newline that surrounds the two statements.
      return ['', `${assignChannel};`, `${assignBatch};`, ''].join('\n');
    };
const bilinearInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
useExtrapolation: boolean, extrapolationValue: number): string => {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
extrapolationValue: number): string => {
const isNchw = true;
const [batchIdx, heightIdx, widthIdx, channelIdx] =
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]);
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
var input_indices: ${input.type.indices};
${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 2) {
${input.indicesSet('input_indices', channelIdx, 'channel')};
${input.indicesSet('input_indices', batchIdx, 'batch')};
};
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
return ${input.getByIndices('input_indices')};
}
@ -334,30 +340,36 @@ const bilinearInterpolation =
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
inputShape[widthIdx]} - 1)) {
${
useExtrapolation ?
`if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1))) {
return ${extrapolationValue};
}
}` :
''};
row = max(0, min(row, ${inputShape[heightIdx]} - 1));
col = max(0, min(col, ${inputShape[widthIdx]} - 1));
var row1: u32 = u32(row);
var col1: u32 = u32(col);
var row2: u32 = u32(row + 1);
var col2: u32 = u32(col + 1);
var channel: u32 = 0;
var batch: u32 = 0;
if (${inputShape.length > 2}) {
channel = u32(originalIndices[${channelIdx}]);
batch = u32(originalIndices[${batchIdx}]);
}
var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'};
var batch: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'};
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
var dx1: ${dType} = row - ${dType}(row1);
var dx2: ${dType} = ${dType}(row2) - row;
var dy1 = col - ${dType}(col1);
var dy2 = ${dType}(col2) - col;
var dx1: ${dType} = abs(row - ${dType}(row1));
var dx2: ${dType} = abs(${dType}(row2) - row);
var dy1: ${dType} = abs(col - ${dType}(col1));
var dy2: ${dType} = abs(${dType}(col2) - col);
if (row1 == row2) {
dx1 = 0.5;
dx2 = 0.5;
}
if (col1 == col2) {
dy1 = 0.5;
dy2 = 0.5;
}
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
}`;
};
@ -366,7 +378,9 @@ const bicubicInterpolation =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
extrapolationValue: number, excludeOutside: boolean): string => {
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];
const is2D = inputShape.length === 2;
const isNchw = true;
const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
const dType = input.type.value;
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
@ -386,16 +400,18 @@ const bicubicInterpolation =
for (var i: i32 = -1; i < 3; i++) {
var ${direction}: ${dType} = originalIdx + ${dType}(i);
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
if (${excludeOutside}) {
coefs[i + 1] = 0.0;
continue;
} else if (${useExtrapolation}) {
return ${extrapolationValue};
} else {
${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
}
${(() => {
if (excludeOutside) {
return `coefs[i + 1] = 0.0;
continue;`;
} else if (useExtrapolation) {
return `return ${extrapolationValue};`;
} else {
return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
}
var input_indices_copy: ${input.type.indices} = input_indices;
})()};
}
var input_indices_copy: ${input.type.indices} = input_indices;
${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
data[i + 1] = ${
idx === heightIdx ? input.getByIndices('input_indices_copy') :
@ -435,6 +451,78 @@ const bicubicInterpolation =
`;
};
/**
 * Generates the WGSL source for trilinear interpolation used by Resize in 'linear' mode.
 *
 * Two shader functions are produced:
 *  - `getInputValue(batch, channel, depth, height, width)`: reads one input element after clamping the three
 *    spatial coordinates into the valid range of the input shape.
 *  - `trilinearInterpolation(output_indices)`: maps an output position back to input coordinates through
 *    `calculateOriginalIndicesFromOutputIndices` (which must be emitted into the same shader by the caller),
 *    fetches the 8 surrounding samples and blends them.
 *
 * Axis layout: a 3-D input is read as [depth, height, width]; otherwise the layout is
 * [batch, channel, depth, height, width] because `isNchw` is hard-coded to true.
 *
 * @param input helper describing the input tensor.
 * @param output helper describing the output tensor.
 * @param inputShape dimensions of the input tensor; the caller only passes rank 3 or rank 5 shapes.
 * @param useExtrapolation when true, coordinates outside the input yield `extrapolationValue`.
 * @param extrapolationValue value returned for out-of-range coordinates when `useExtrapolation` is set.
 * @returns the WGSL source text of both functions.
 */
const trilinearInterpolation =
    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
     extrapolationValue: number): string => {
      const isNchw = true;
      const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
          inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
      const dType = input.type.value;
      const hasBatchAndChannel = inputShape.length > 3;

      // Fix: the previous condition ended in ")))" and therefore emitted one unmatched ')' into the shader,
      // which broke WGSL compilation whenever extrapolation was enabled. Every comparison is now
      // parenthesized the same way and the parentheses are balanced.
      const extrapolationCheck = useExtrapolation ?
          `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
              inputShape[heightIdx]} - 1) || width < 0 || width > (${inputShape[widthIdx]} - 1)) {
      return ${extrapolationValue};
    }` :
          '';

      return `
    fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
      var input_indices: ${input.type.indices};
      ${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
      ${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
      ${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
      ${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
      return ${input.getByIndices('input_indices')};
    }

    fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
      var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
      var depth:${dType} = originalIndices[${depthIdx}];
      var height:${dType} = originalIndices[${heightIdx}];
      var width:${dType} = originalIndices[${widthIdx}];
      ${extrapolationCheck};

      depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
      height = max(0, min(height, ${inputShape[heightIdx]} - 1));
      width = max(0, min(width, ${inputShape[widthIdx]} - 1));
      var depth1: u32 = u32(depth);
      var height1: u32 = u32(height);
      var width1: u32 = u32(width);
      var depth2: u32 = u32(depth + 1);
      var height2: u32 = u32(height + 1);
      var width2: u32 = u32(width + 1);
      var channel: u32 = ${hasBatchAndChannel ? `u32(originalIndices[${channelIdx}])` : '0'};
      var batch: u32 = ${hasBatchAndChannel ? `u32(originalIndices[${batchIdx}])` : '0'};

      var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
      var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
      var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
      var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
      var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
      var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
      var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
      var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
      var dx1: ${dType} = abs(depth - ${dType}(depth1));
      var dx2: ${dType} = abs(${dType}(depth2) - depth);
      var dy1: ${dType} = abs(height - ${dType}(height1));
      var dy2: ${dType} = abs(${dType}(height2) - height);
      var dz1: ${dType} = abs(width - ${dType}(width1));
      var dz2: ${dType} = abs(${dType}(width2) - width);
      if (depth1 == depth2) {
        dx1 = 0.5;
        dx2 = 0.5;
      }
      if (height1 == height2) {
        dy1 = 0.5;
        dy2 = 0.5;
      }
      if (width1 == width2) {
        dz1 = 0.5;
        dz2 = 0.5;
      }
      return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
              x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
    }`;
    };
const createResizeProgramInfo =
(inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[],
sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
@ -454,6 +542,7 @@ const createResizeProgramInfo =
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const extrapolationValue = attributes.extrapolationValue;
const dataType = input.type.value;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${noScale ? '' : `
@ -471,16 +560,28 @@ const createResizeProgramInfo =
case 'linear':
return `
${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)};
`;
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
} else if (inputShape.length === 3 || inputShape.length === 5) {
return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
} else {
throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
}
})()};
`;
case 'cubic':
return `
${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)};
${(() => {
if (inputShape.length === 2 || inputShape.length === 4) {
return `${
bicubicInterpolation(
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)}`;
} else {
throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.');
}
})()};
`;
default:
throw Error('Invalid resize mode');
@ -507,21 +608,23 @@ const createResizeProgramInfo =
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
return 'output[global_idx] = bilinearInterpolation(output_indices);';
return `output[global_idx] = ${
(inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' :
'trilinearInterpolation'}(output_indices);`;
case 'cubic':
return 'output[global_idx] = bicubicInterpolation(output_indices);';
default:
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
})()};
`}
`}
}`;
return {
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
inputDependencies: ['rank']
},
getShaderSource,
@ -551,6 +654,9 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v
const sizes: number[] = [];
const roi: number[] = [];
const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
if (attributes.antialias !== 0) {
throw Error('Only default value (0) for Antialias attribute is supported');
}
validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);
context.compute(
createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]});

View file

@ -2,6 +2,7 @@
{
"name": "Upsample - Nearest",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 3.0], "type": "floats" }],
"cases": [
{
@ -32,6 +33,7 @@
{
"name": "Upsample - Nearest2X",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 2.0], "type": "floats" }],
"cases": [
{
@ -60,6 +62,7 @@
{
"name": "Upsample - Nearest222X",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [{ "name": "scales", "data": [2.0, 1.0, 2.0, 2.0], "type": "floats" }],
"cases": [
{
@ -92,6 +95,7 @@
{
"name": "Upsample - Nearest15X",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 1.5], "type": "floats" }],
"cases": [
{
@ -120,6 +124,7 @@
{
"name": "Upsample - Nearest_NoScale",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" },
{ "name": "mode", "data": "nearest", "type": "string" }
@ -147,6 +152,7 @@
{
"name": "Upsample - 4D Bilinear",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 1.0, 2.0, 4.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
@ -180,6 +186,7 @@
{
"name": "Upsample - 2D Bilinear",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [2.0, 4.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
@ -210,6 +217,7 @@
{
"name": "Upsample - 4D Bilinear ScalesNoOp",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
@ -237,6 +245,7 @@
{
"name": "Upsample - 1D Nearest",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [2.0], "type": "floats" },
{ "name": "mode", "data": "nearest", "type": "string" }
@ -260,5 +269,98 @@
]
}
]
},
{
"name": "Upsample - 5D Trilinear",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 1.0, 1.0, 2.0, 4.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
],
"cases": [
{
"name": "X",
"inputs": [
{
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
"dims": [1, 2, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [
1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0,
5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0,
3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0, 5.0, 5.5, 6.0, 6.5, 7.0, 7.0, 7.0, 7.0, 7.0, 7.5, 8.0, 8.5, 9.0,
9.0, 9.0, 9.0, 7.0, 7.5, 8.0, 8.5, 9.0, 9.0, 9.0, 9.0
],
"dims": [1, 2, 1, 4, 8],
"type": "float32"
}
]
}
]
},
{
"name": "Upsample - 3D Trilinear",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 2.0, 4.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
],
"cases": [
{
"name": "X",
"inputs": [
{
"data": [1.0, 3.0, 3.0, 5.0],
"dims": [1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [
1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0,
5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0
],
"dims": [1, 4, 8],
"type": "float32"
}
]
}
]
},
{
"name": "Upsample - 3D Trilinear ScalesNoOp",
"operator": "Upsample",
"opset": { "domain": "", "version": 7 },
"attributes": [
{ "name": "scales", "data": [1.0, 1.0, 1.0], "type": "floats" },
{ "name": "mode", "data": "linear", "type": "string" }
],
"cases": [
{
"name": "X",
"inputs": [
{
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
"dims": [2, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
"dims": [2, 2, 2],
"type": "float32"
}
]
}
]
}
]

View file

@ -1392,6 +1392,7 @@
"tile.jsonc",
"transpose.jsonc",
"transpose_int32_uint32.jsonc",
"upsample.jsonc",
"where.jsonc"
// Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed.
//"where_broadcast.jsonc",

View file

@ -3,8 +3,8 @@
#pragma once
#include <vector>
#include <string>
#include <vector>
#include "core/providers/js/js_kernel.h"
#include "core/providers/cpu/nn/conv_attributes.h"
@ -17,7 +17,6 @@ class ConvBase : public JsKernel {
ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info),
conv_attrs_(info),
w_is_const_(false) {
std::vector<float> activation_params;
TensorShapeVector kernel_shape;
const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size();
std::vector<int32_t> local_pads(pads_vec_size, 0);
@ -28,13 +27,8 @@ class ConvBase : public JsKernel {
if (conv_attrs_.kernel_shape_specified) {
ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
}
if (is_fused_conv) {
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &conv_attrs_.activation));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("activation_params", activation_params));
} else {
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
activation_params = info.GetAttrsOrDefault<float>("activation_params", activation_params);
}
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr;
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;