[JS/WebGPU] Add Resize operator (#16680)

### Description
Implemented Resize operator support in JSEP



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-07-31 09:35:06 -07:00 committed by GitHub
parent 3fd1d3b9bd
commit 77b2b618b2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 879 additions and 35 deletions

View file

@ -22,8 +22,8 @@ Do not modify directly.*
| Ceil | ai.onnx(6-12,13+) | |
| Clip | ai.onnx(6-10,11,12,13+) | |
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | |
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation |
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
| Div | ai.onnx(7-12,13,14+) | |
@ -58,6 +58,7 @@ Do not modify directly.*
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
| Relu | ai.onnx(6-12,13,14+) | |
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
| Sin | ai.onnx(7+) | |

View file

@ -11,6 +11,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
import * as pool from './ops/pool';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSplitAttributes, split} from './ops/split';
import {parseTransposeAttributes, transpose} from './ops/transpose';
@ -69,6 +70,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]],
['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]],
['Relu', [unaryOps.relu]],
['Resize', [resize, parseResizeAttributes]],
['Sigmoid', [unaryOps.sigmoid]],
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],

View file

@ -0,0 +1,595 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createIndicesHelper, ShaderHelper} from './common';
// Names of the coordinate transformation modes this file recognises. Each one selects the WGSL body generated by
// getOriginalCoordinateFromResizedCoordinate().
type CoordinateTransformMode = 'half_pixel'|'asymmetric'|'pytorch_half_pixel'|'tf_half_pixel_for_nn'|'align_corners'|
'tf_crop_and_resize'|'half_pixel_symmetric';
// Policy consulted when the output shape comes from "sizes": anything other than 'stretch' goes through
// adjustOutputShape().
type KeepAspectRatioPolicy = 'stretch'|'not_smaller'|'not_larger';
// Interpolation mode; selects which shader helpers createResizeProgramInfo() emits.
type Mode = 'nearest'|'linear'|'cubic';
// Rounding rule used in 'nearest' mode. 'simple' is what parseResizeAttributes() substitutes for an empty attribute.
type NearestMode = 'round_prefer_floor'|'round_prefer_ceil'|'floor'|'ceil'|'simple';
/**
 * Parsed attributes of a Resize node, produced by parseResizeAttributes().
 */
export interface ResizeAttributes extends AttributeWithCacheKey {
// Carried through from the node attributes; no code in this file reads it.
antialias: number;
// Axes that the "scales"/"sizes"/"roi" inputs refer to. An empty array means all dimensions.
axes: number[];
// Selects how an output coordinate is mapped back to an input coordinate; 'tf_crop_and_resize' also enables
// extrapolation.
coordinateTransformMode: CoordinateTransformMode;
// Coefficient substituted into the cubic interpolation weights (cubic mode only).
cubicCoeffA: number;
// Cubic mode: when true, taps that fall outside the input get a zero weight.
excludeOutside: boolean;
// Value written for output elements whose source coordinate is outside the input when extrapolation is enabled.
extrapolationValue: number;
// Only consulted when the output shape is derived from "sizes".
keepAspectRatioPolicy: KeepAspectRatioPolicy;
// Interpolation mode.
mode: Mode;
// Rounding rule used by nearest mode.
nearestMode: NearestMode;
}
/**
 * Validates the values read from the "scales" input.
 *
 * @param scales values of the scales input (may be empty when "sizes" is used instead).
 * @param attributes parsed Resize attributes; only `mode` is consulted.
 * @throws Error when any scale is not strictly positive (NaN included), or when `mode` is 'linear' or 'cubic' and
 *     the scales do not describe a 2-D resize (length 2, or length 4 with the outer two or the first and last
 *     entries equal to 1).
 */
const validateScales = (scales: number[], attributes: ResizeAttributes): void => {
  // The previous form `value > 0 || (() => { throw ... })` only created the throwing function and never called
  // it, so the check could not fail. `!(value > 0)` also rejects NaN.
  if (scales.some((value) => !(value > 0))) {
    throw new Error('Resize requires scales input values to be positive');
  }
  // Check scales dims based on mode: LINEAR, CUBIC
  if (scales.length > 0 && (attributes.mode === 'linear' || attributes.mode === 'cubic')) {
    const isTwoDimensional = scales.length === 2;
    const outerTwoAreUnit = scales.length === 4 && scales[0] === 1 && scales[1] === 1;
    const firstAndLastAreUnit = scales.length === 4 && scales[0] === 1 && scales[3] === 1;
    if (!(isTwoDimensional || outerTwoAreUnit || firstAndLastAreUnit)) {
      throw new Error(`Resize requires scales input size to be 2 or 4 for ${attributes.mode} mode`);
    }
  }
};
/**
 * Expands per-axis scales to one scale per input dimension.
 *
 * @param scales scales given in the order of `axes`.
 * @param axes dimensions the scales apply to.
 * @param rank rank of the input tensor.
 * @returns an array of length `rank` holding 1.0 everywhere except at the listed axes.
 * @throws Error when an axis is outside `[0, rank)`.
 */
const updateScales = (scales: readonly number[], axes: readonly number[], rank: number): number[] => {
  // The previous `every(... || (() => { throw ... }))` never invoked the throwing function, so out-of-range axes
  // were silently written outside the array.
  if (axes.some((axis) => !(axis >= 0 && axis < rank))) {
    throw new Error('Resize requires axes input values to be positive and less than rank');
  }
  const newScales = new Array<number>(rank).fill(1.0);
  axes.forEach((axis, index) => {
    newScales[axis] = scales[index];
  });
  return newScales;
};
/**
 * Reads the optional "roi", "scales" and "sizes" inputs into the caller-supplied arrays and validates them.
 *
 * Input positions depend on the opset: above opset 10 the inputs are [X, roi, scales, sizes]; otherwise only
 * [X, scales] exists.
 *
 * @param inputs kernel inputs; `inputs[0]` is the tensor to resize.
 * @param attributes parsed Resize attributes.
 * @param opsetVersion opset the node was registered with.
 * @param scales out-parameter; receives the scales, expanded to one entry per dimension when axes are used.
 * @param sizes out-parameter; receives the requested output sizes.
 * @param roi out-parameter; receives the region-of-interest values.
 * @throws Error when roi is missing for 'tf_crop_and_resize', when scales/sizes have an unexpected length, when a
 *     scale is invalid, or when both scales and sizes are supplied.
 */
const validateInputs =
    (inputs: readonly TensorView[], attributes: ResizeAttributes, opsetVersion: number, scales: number[],
     sizes: number[], roi: number[]): void => {
      const [roiInputIndex, scalesInputIndex, sizesInputIndex] =
          (opsetVersion > 10) ? [1, 2, 3] : [-1, (inputs.length > 1) ? 1 : -1, -1];
      const rank = inputs[0].dims.length;
      // From opset 18 on, a non-empty axes attribute means scales/sizes carry one entry per listed axis;
      // otherwise they carry one entry per input dimension.
      const useAxes = opsetVersion >= 18 && attributes.axes.length > 0;
      const expectedLength = useAxes ? attributes.axes.length : rank;

      if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) {
        inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value));
      } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') {
        throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
      }

      if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
        inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
        // An empty scales tensor means "use sizes"; it must not be validated or expanded.
        if (scales.length !== 0) {
          if (scales.length !== expectedLength) {
            throw new Error(
                'Resize requires scales input size to be same as input rank or axes size for opset 18 and up');
          }
          validateScales(scales, attributes);
          if (useAxes) {
            // Expand in place so downstream code always sees one scale per input dimension.
            updateScales(scales, attributes.axes, rank).forEach((value, index) => {
              scales[index] = value;
            });
          }
        }
      }

      if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
        inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
        if (sizes.length !== expectedLength) {
          throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
        }
      }

      // Exactly one of scales / sizes may be supplied.
      if (scales.length > 0 && sizes.length > 0) {
        throw new Error('Resize requires only of scales or sizes to be specified');
      }
    };
/**
 * Generates the WGSL helper `getOriginalCoordinateFromResizedCoordinate`, which maps a coordinate in the resized
 * (output) tensor to the corresponding coordinate in the original (input) tensor along one axis.
 *
 * @param coordinateTransferMode coordinate transformation mode that selects the function body.
 * @returns WGSL source of the helper function.
 * @throws Error when the mode is not one of the supported values.
 */
const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => {
  const body = ((): string => {
    switch (coordinateTransferMode) {
      case 'asymmetric':
        return 'return xResized / xScale;';
      case 'pytorch_half_pixel':
        return 'if (lengthResized > 1) { ' +
            'return (xResized + 0.5) / xScale - 0.5; ' +
            '} else { ' +
            'return 0.0; ' +
            '}';
      case 'tf_half_pixel_for_nn':
        return 'return (xResized + 0.5) / xScale;';
      case 'align_corners':
        return 'if (lengthResized == 1) { ' +
            'return 0.0; ' +
            '} else { ' +
            'return xResized * (lengthOriginal - 1) / (lengthResized - 1); ' +
            '}';
      case 'tf_crop_and_resize':
        return 'if (lengthResized > 1) { ' +
            'return roiStart * (lengthOriginal - 1) + ' +
            '(xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); ' +
            '} else { ' +
            'return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1); ' +
            '}';
      case 'half_pixel_symmetric':
        // These values depend on function parameters, so they must be `let` declarations: a WGSL `const`
        // initializer has to be a const-expression. `outputWidth` is the un-rounded target length,
        // i.e. scale * original length.
        return [
          'let outputWidth = xScale * lengthOriginal;', 'let adjustment = lengthResized / outputWidth;',
          'let center = lengthOriginal / 2;', 'let offset = center * (1 - adjustment);',
          'return offset + ((xResized + 0.5) / xScale) - 0.5;'
        ].join('\n');
      case 'half_pixel':
        return 'return ((xResized + 0.5) / xScale) - 0.5;';
      default:
        throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
    }
  })();
  return 'fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,' +
      'lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 { ' + body + '}';
};
/**
 * Generates the WGSL helper `getNearestPixelFromOriginal`, which rounds an original-space coordinate according to
 * the requested nearest mode.
 *
 * 'simple' (and any unrecognised mode) is only accepted below opset 11; otherwise an Error is thrown.
 */
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => {
  const signature = 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {';
  // Exact .5 fractions are resolved by the given function; everything else goes through round().
  const roundWithTieBreak = (tieBreak: 'ceil'|'floor'): string =>
      `if (fract(xOriginal) == 0.5) { return ${tieBreak}(xOriginal); } else { return round(xOriginal); }`;

  let body: string;
  if (nearestMode === 'round_prefer_ceil') {
    body = roundWithTieBreak('ceil');
  } else if (nearestMode === 'floor') {
    body = 'return floor(xOriginal);';
  } else if (nearestMode === 'ceil') {
    body = 'return ceil(xOriginal);';
  } else if (nearestMode === 'round_prefer_floor') {
    body = roundWithTieBreak('floor');
  } else {
    // 'simple' or anything unknown
    if (!(opsetVersion < 11)) {
      throw new Error(`Nearest mode ${nearestMode} is not supported`);
    }
    body = 'if (isDownSample) { return ceil(xOriginal); } else { return xOriginal; }';
  }
  return signature + body + '}';
};
/**
 * Normalises the region of interest to the layout [start_0..start_{rank-1}, end_0..end_{rank-1}].
 *
 * @param roi values of the roi input; with axes they are laid out as [starts for axes..., ends for axes...].
 * @param axes dimensions the roi refers to; empty means roi already covers every dimension.
 * @param rank rank of the input tensor.
 * @returns a new array. Dimensions not covered by `axes` (or all of them when `roi` is empty) get start 0, end 1.
 */
const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number): number[] => {
  const fullRoi: number[] = [...new Array<number>(rank).fill(0), ...new Array<number>(rank).fill(1)];
  if (roi.length === 0) {
    // No roi supplied: keep the defaults. Previously the defaults aliased the source array and were
    // overwritten with their own entries when axes were given.
    return fullRoi;
  }
  if (axes.length === 0) {
    return roi.slice();
  }
  axes.forEach((axis, i) => {
    fullRoi[axis] = roi[i];
    // The end coordinate belongs to the same dimension as the start, i.e. `axis + rank` (not `i + rank`).
    fullRoi[axis + rank] = roi[axes.length + i];
  });
  return fullRoi;
};
/**
 * Computes the initial output shape from either "sizes" or "scales".
 *
 * @param inputShape dimensions of the input tensor.
 * @param scales one scale per input dimension; used only when `sizes` is empty.
 * @param sizes requested output sizes; with `axes` they are given in the order of `axes`.
 * @param axes dimensions `sizes` refers to; empty means every dimension.
 * @returns the output dimensions.
 * @throws Error when an axis is outside `[0, rank)` or when neither scales nor sizes is available.
 */
const initOutputShape =
    (inputShape: readonly number[], scales: readonly number[], sizes: readonly number[], axes: readonly number[]):
        number[] => {
      if (sizes.length > 0) {
        if (axes.length === 0) {
          return sizes.slice();
        }
        // Valid axes are 0..rank-1. The previous check (`max(axes) > rank`) let `rank` itself through and
        // ignored negative values, both of which wrote outside the shape.
        if (axes.some((axis) => !(axis >= 0 && axis < inputShape.length))) {
          throw new Error('axes is out of bound');
        }
        const outputShape = inputShape.slice();
        axes.forEach((axis, i) => {
          outputShape[axis] = sizes[i];
        });
        return outputShape;
      }
      if (scales.length === 0) {
        throw new Error('Resize requires either scales or sizes.');
      }
      // NOTE(review): the ONNX operator text describes flooring here; rounding is kept as in the original.
      return inputShape.map((value, index) => Math.round(value * scales[index]));
    };
/**
 * Applies the keep-aspect-ratio policy: picks one scale ('not_larger' takes the smallest, 'not_smaller' the
 * largest, over the listed axes or over all scales), rewrites `scales` in place with it and returns the matching
 * output shape.
 *
 * Dimensions that are not listed in `attributes.axes` keep scale 1 and their input size. Any other policy throws.
 */
const adjustOutputShape =
    (inputShape: readonly number[], outputShape: readonly number[], scales: number[], attributes: ResizeAttributes):
        number[] => {
      let uniformScale: number;
      switch (attributes.keepAspectRatioPolicy) {
        case 'not_larger': {
          const candidates = attributes.axes.length > 0 ? attributes.axes.map((axis) => scales[axis]) : scales;
          uniformScale = Math.min(...candidates, Number.MAX_VALUE);
          break;
        }
        case 'not_smaller': {
          const candidates = attributes.axes.length > 0 ? attributes.axes.map((axis) => scales[axis]) : scales;
          uniformScale = Math.max(...candidates, Number.MIN_VALUE);
          break;
        }
        default:
          throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`);
      }

      // Reset every scale, then apply the chosen one where it belongs.
      for (let i = 0; i < scales.length; i++) {
        scales[i] = 1.0;
      }
      const result = inputShape.slice();
      if (attributes.axes.length > 0) {
        for (const axis of attributes.axes) {
          scales[axis] = uniformScale;
        }
        for (const axis of attributes.axes) {
          result[axis] = Math.round(inputShape[axis] * scales[axis]);
        }
      } else {
        for (let i = 0; i < scales.length; i++) {
          scales[i] = uniformScale;
        }
        for (let i = 0; i < result.length; i++) {
          result[i] = Math.round(result[i] * scales[i]);
        }
      }
      return result;
    };
/**
 * Generates the WGSL function `calculateOriginalIndicesFromOutputIndices`, which converts output indices into
 * floating-point coordinates in the input tensor, one per dimension.
 *
 * The input shape, output shape, scales and roi are baked into the shader as constant arrays. A dimension whose
 * scale is exactly 1.0 passes its index through unchanged; every other dimension goes through the WGSL helper
 * getOriginalCoordinateFromResizedCoordinate, with roi[i] as the start and roi[i + rank] as the end.
 */
const calculateOriginalIndicesFromOutputIndices =
(inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]):
string => {
// Only the WGSL type name of the indices is needed from the helper.
const outputIndicesHelper = createIndicesHelper('output', outputShape);
return `
fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${outputIndicesHelper.iType}) -> array<f32, ${
outputShape.length}> {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var originalIndices: array<f32, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
if (scales[i] == 1.0) {
originalIndices[i] = f32(outputIndex);
} else {
originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
}
}
return originalIndices;
}`;
};
/**
 * Generates the WGSL function `calculateInputIndicesFromOutputIndices` used by nearest mode: for every dimension
 * it maps the output index to an input index.
 *
 * Per dimension:
 *  - scale == 1.0: the index is copied;
 *  - otherwise the original coordinate is computed. When extrapolation is off, or the coordinate lies inside
 *    [0, inputShape[i]), it is clamped to the valid range or rounded with getNearestPixelFromOriginal;
 *  - when extrapolation is on and the coordinate is outside that range, the raw coordinate is converted with
 *    u32() so that the caller's checkInputIndices() can reject it.
 *
 * NOTE(review): that last step relies on u32() of an out-of-range coordinate producing an index that
 * checkInputIndices() treats as invalid. For negative coordinates this depends on how WGSL converts a negative
 * f32 to u32 - confirm that negative coordinates really end up using the extrapolation value.
 */
const calculateInputIndicesFromOutputIndices =
(inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[],
useExtrapolation: boolean): string => {
// The helpers are only used for the WGSL type names of the index values.
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const inputIndicesHelper = createIndicesHelper('input', inputShape);
return `
fn calculateInputIndicesFromOutputIndices(outputIndices: ${outputIndicesHelper.iType}) -> array<u32, ${
inputShape.length}> {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
var inputIndices: ${inputIndicesHelper.iType};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex: u32;
if (scales[i] == 1.0) {
inputIndex = outputIndex;
} else {
var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) {
if (original_idx < 0) {
inputIndex = 0;
} else if (original_idx > (f32(inputShape[i]) - 1)) {
inputIndex = inputShape[i] - 1;
} else {
inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
}
} else {
inputIndex = u32(original_idx);
}
}
${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
}
return inputIndices;
}`;
};
/**
 * Generates the WGSL function `checkInputIndices`, which returns true only when every component of the given
 * indices is inside the corresponding input dimension (the input shape is baked in as a constant array).
 */
const checkInputIndices = (inputShape: readonly number[]): string => {
// Only the indices type name is taken from the helper; the 'output' label is not used in the generated code.
const inputIndicesHelper = createIndicesHelper('output', inputShape);
// NOTE(review): the indices are compared against u32 shape values, so they are presumably unsigned and the
// `inputIndex < 0` test below never fires - confirm against the type produced by createIndicesHelper.
return `
fn checkInputIndices(inputIndices: ${inputIndicesHelper.iType}) -> bool {
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
for (var i:u32 = 0; i < ${inputShape.length}; i++) {
var inputIndex = ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'};
if (inputIndex < 0 || inputIndex >= inputShape[i]) {
return false;
}
}
return true;
}`;
};
/**
 * Generates the WGSL functions `getInputValue` and `bilinearInterpolation` used by linear mode.
 *
 * The spatial dimensions are chosen from the input rank and scales: a rank-2 input is [height, width]; otherwise
 * the layout is [batch, channel, height, width] when scales[1] is 1, else [batch, height, width, channel].
 *
 * @param inputShape dimensions of the input tensor.
 * @param outputShape dimensions of the output tensor.
 * @param scales one scale per input dimension.
 * @param useExtrapolation when true, coordinates outside the input yield `extrapolationValue`.
 * @param extrapolationValue value returned for extrapolated elements.
 * @returns WGSL source. It relies on `calculateOriginalIndicesFromOutputIndices`, the `input` binding and the
 *     input indices-to-offset helper being declared elsewhere in the shader.
 */
const bilinearInterpolation =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
     useExtrapolation: boolean, extrapolationValue: number): string => {
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const [batchIdx, heightIdx, widthIdx, channelIdx] =
          inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
      // Batch/channel statements are only emitted when those dimensions exist. For a rank-2 input the indices
      // are -1, and a constant out-of-range index is rejected when the shader is created even inside a branch
      // that can never run.
      const hasBatchAndChannel = inputShape.length > 2;
      const storeBatchAndChannel = hasBatchAndChannel ? `
    inputIndices[${channelIdx}] = channel;
    inputIndices[${batchIdx}] = batch;` :
                                                        '';
      const loadBatchAndChannel = hasBatchAndChannel ? `
    channel = u32(originalIndices[${channelIdx}]);
    batch = u32(originalIndices[${batchIdx}]);` :
                                                       '';
      return `
  fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 {
    var inputIndices: ${inputIndicesHelper.iType};
    inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
    inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));${storeBatchAndChannel}
    return input[${inputIndicesHelper.i2oExpression('inputIndices')}];
  }
  fn bilinearInterpolation(outputIndices: ${outputIndicesHelper.iType}) -> f32 {
    var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
    var row:f32 = originalIndices[${heightIdx}];
    var col:f32 = originalIndices[${widthIdx}];
    if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
          inputShape[widthIdx]} - 1)) {
      return ${extrapolationValue};
    }
    row = max(0, min(row, ${inputShape[heightIdx]} - 1));
    col = max(0, min(col, ${inputShape[widthIdx]} - 1));
    var row1: u32 = u32(row);
    var col1: u32 = u32(col);
    var row2: u32 = u32(row + 1);
    var col2: u32 = u32(col + 1);
    var channel: u32 = 0;
    var batch: u32 = 0;${loadBatchAndChannel}
    var x11: f32 = getInputValue(batch, channel, row1, col1);
    var x12: f32 = getInputValue(batch, channel, row1, col2);
    var x21: f32 = getInputValue(batch, channel, row2, col1);
    var x22: f32 = getInputValue(batch, channel, row2, col2);
    var dx1: f32 = row - f32(row1);
    var dx2: f32 = f32(row2 ) - row;
    var dy1 = col - f32(col1);
    var dy2 = f32(col2) - col;
    return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
  }`;
    };
/**
 * Generates the WGSL functions used by cubic mode: `rowCubicInterpolation`, `colCubicInterpolation`,
 * `getCubicInterpolationCoefs`, `cubicInterpolation1D` and the entry point `bicubicInterpolation`.
 *
 * The column pass calls the row pass for each of its four taps, so the result is a separable 4x4 interpolation
 * over the height and width dimensions. Those are dimensions [0, 1] for a rank-2 input, [2, 3] when scales[1] is
 * 1, otherwise [1, 2].
 *
 * @param inputShape dimensions of the input tensor.
 * @param outputShape dimensions of the output tensor.
 * @param scales one scale per input dimension.
 * @param roi region of interest laid out as [starts..., ends...], each part `rank` long.
 * @param cubicCoeffA coefficient used in the cubic weights.
 * @param useExtrapolation when true, coordinates outside the input yield `extrapolationValue`.
 * @param extrapolationValue value returned for extrapolated elements.
 * @param excludeOutside when true, taps outside the input get weight 0 instead of being clamped.
 * @returns WGSL source. It relies on `getOriginalCoordinateFromResizedCoordinate`, the `input` binding and the
 *     input indices-to-offset helper being declared elsewhere in the shader.
 */
const bicubicInterpolation =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[],
     cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => {
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];

      // Emits the 1-D pass along dimension `idx`. The height pass reads the input directly; the width pass
      // obtains each tap from the height pass.
      const createCubicInterpolationFunction = (idx: number): string => {
        const direction = idx === heightIdx ? 'row' : 'col';
        // roi holds all starts followed by all ends, so the end for this dimension is at `idx + rank`.
        // (The previous code passed `roi[idx] + rank`, i.e. the start value plus the rank.)
        const roiStart = roi[idx];
        const roiEnd = roi[idx + inputShape.length];
        return `
  fn ${direction}CubicInterpolation(inputIndices: ${inputIndicesHelper.iType}, outputIndices: ${
            outputIndicesHelper.iType}) -> f32 {
    var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
    var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]},
      f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roiStart}, ${roiEnd});
    var fractOriginalIdx: f32 = originalIdx - floor(originalIdx);
    var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
    if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
      return ${extrapolationValue};
    }
    var data: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
    for (var i: i32 = -1; i < 3; i++) {
      var ${direction}: f32 = originalIdx + f32(i);
      if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
        if (${excludeOutside}) {
          coefs[i + 1] = 0.0;
          continue;
        } else if (${useExtrapolation}) {
          return ${extrapolationValue};
        } else {
          ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
        }
      }
      var inputIndicesCopy: ${inputIndicesHelper.iType} = inputIndices;
      inputIndicesCopy[${idx}] = u32(${direction});
      data[i + 1] = ${idx === heightIdx ? `input[${inputIndicesHelper.i2oExpression('inputIndicesCopy')}];` : `
                                           rowCubicInterpolation(inputIndicesCopy, outputIndices);`}
    }
    return cubicInterpolation1D(data, coefs);
  }`;
      };

      return `
  ${createCubicInterpolationFunction(heightIdx)};
  ${createCubicInterpolationFunction(widthIdx)};
  fn getCubicInterpolationCoefs(s: f32) -> array<f32, 4> {
    var absS = abs(s);
    var coeffs: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
    var oneMinusAbsS: f32 = 1.0 - absS;
    var twoMinusAbsS: f32 = 2.0 - absS;
    var onePlusAbsS: f32 = 1.0 + absS;
    coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
          cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA};
    coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
    coeffs[2] = ((${cubicCoeffA} + 2) * oneMinusAbsS - (${cubicCoeffA} + 3)) * oneMinusAbsS * oneMinusAbsS + 1;
    coeffs[3] = ((${cubicCoeffA} * twoMinusAbsS - 5 * ${cubicCoeffA}) * twoMinusAbsS + 8 * ${
          cubicCoeffA}) * twoMinusAbsS - 4 * ${cubicCoeffA};
    return coeffs;
  }
  fn cubicInterpolation1D(x: array<f32, 4>, coefs: array<f32, 4>) -> f32 {
    var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3];
    return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
  }
  fn bicubicInterpolation(outputIndices: ${outputIndicesHelper.iType}) -> f32 {
    var inputIndices: ${inputIndicesHelper.iType} = outputIndices;
    return colCubicInterpolation(inputIndices, outputIndices);
  }
    `;
    };
/**
 * Builds the program info (shader source generator, output description and dispatch size) for one Resize call.
 *
 * Steps:
 *  1. normalise roi and derive the output shape from sizes or scales;
 *  2. when no scales were given, derive them from the shapes and, unless the policy is 'stretch', adjust the
 *     output shape for the keep-aspect-ratio policy;
 *  3. assemble the WGSL: the coordinate transform helper, the mode specific helpers and a main function that
 *     either copies the input (shapes identical) or evaluates the selected interpolation per output element.
 *
 * Throws when `attributes.mode` is not 'nearest', 'linear' or 'cubic' (at shader generation time).
 */
const createResizeProgramInfo =
(metadata: ProgramMetadata, input: TensorView, attributes: ResizeAttributes, opsetVersion: number,
scalesInput: readonly number[], sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
const inputShape = input.dims;
const roi = updateRoI(roiInput, attributes.axes, inputShape.length);
let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes);
let scales = scalesInput.slice();
if (scalesInput.length === 0) {
// Sizes were given: recover the effective scale of each dimension (a zero-sized dimension keeps scale 1).
scales = inputShape.map((value, index) => value === 0 ? 1.0 : outputShape[index] / value);
if (attributes.keepAspectRatioPolicy !== 'stretch') {
outputShape = adjustOutputShape(inputShape, outputShape, scales, attributes);
}
}
const outputIndicesHelper = createIndicesHelper('output', outputShape);
const inputIndicesHelper = createIndicesHelper('input', inputShape);
const outputSize = ShapeUtil.size(outputShape);
const dataType = 'f32';
// When the shapes are identical the shader degenerates to an element-wise copy.
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
// Extrapolation is only used by the tf_crop_and_resize coordinate transform.
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
const getShaderSource = (shaderHelper: ShaderHelper) => `
${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
${(() => {
switch (attributes.mode) {
case 'nearest':
return `
${checkInputIndices(inputShape)};
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)};
${calculateInputIndicesFromOutputIndices(inputShape, outputShape, scales, roi, useExtrapolation)};
`;
case 'linear':
return `
${calculateOriginalIndicesFromOutputIndices(inputShape, outputShape, scales, roi)};
${
bilinearInterpolation(
inputShape, outputShape, scales, useExtrapolation, attributes.extrapolationValue)};
`;
case 'cubic':
return `
${
bicubicInterpolation(
inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
attributes.extrapolationValue, attributes.excludeOutside)};
`;
default:
throw Error('Invalid resize mode');
}
})()};
@group(0) @binding(0) var<storage, read> input : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
${outputIndicesHelper.o2iImpl}
${inputIndicesHelper.i2oImpl}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
if (${noScale}) {
output[global_idx] = input[global_idx];
} else {
${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
${(() => {
switch (attributes.mode) {
case 'nearest':
return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices);
if (checkInputIndices(inputIndices)) {
output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}];
} else {
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
return 'output[global_idx] = bilinearInterpolation(outputIndices);';
case 'cubic':
return 'output[global_idx] = bicubicInterpolation(outputIndices);';
default:
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
})()};
}
}`;
return {
...metadata,
getShaderSource,
outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};
/**
 * Creates the lazy program loader for a Resize call.
 *
 * The cache hint has to distinguish every value that is baked into the generated shader source. Besides the
 * attributes and the opset this covers scales, sizes and roi, all of which are emitted as shader constants.
 *
 * @param input tensor to resize.
 * @param attributes parsed Resize attributes.
 * @param opsetVersion opset the node was registered with.
 * @param scales validated scales (empty when sizes are used).
 * @param sizes validated sizes (empty when scales are used).
 * @param roi region-of-interest values (may be empty).
 * @returns program metadata plus a `get` function that builds the program info on demand.
 */
export const createResizeProgramInfoLoader =
    (input: TensorView, attributes: ResizeAttributes, opsetVersion: number, scales: readonly number[],
     sizes: readonly number[], roi: readonly number[]): ProgramInfoLoader => {
      const hintParts = [
        attributes.cacheKey + opsetVersion.toString(),
        scales.length > 0 ? '_scales_' + scales.toString() : '',
        sizes.length > 0 ? '_sizes_' + sizes.toString() : '',
        // roi is compiled into the shader, so two calls that differ only in roi must not share a program.
        roi.length > 0 ? '_roi_' + roi.toString() : '',
      ];
      const metadata: ProgramMetadata = {
        name: 'Resize',
        inputTypes: [GpuDataType.default],
        cacheHint: hintParts.join(''),
      };
      return {
        ...metadata,
        get: () => createResizeProgramInfo(metadata, input, attributes, opsetVersion, scales, sizes, roi)
      };
    };
/**
 * Reads the opset version, stored as the first little-endian 32-bit unsigned integer of the kernel's custom data.
 *
 * @param context compute context whose `customDataBuffer` holds the custom data bytes.
 * @returns the opset version.
 * @throws Error when the custom data is shorter than four bytes.
 */
const getOpsetVersionFromCustomDataBuffer = (context: ComputeContext): number => {
  const customDataBuffer = context.customDataBuffer;
  // Passing a typed array to `new Uint32Array(...)` copies its elements one by one and ignores the offset and
  // length arguments, so only the first byte was ever returned. A DataView reads the full 32-bit value and
  // does not require the offset to be 4-byte aligned.
  const view = ArrayBuffer.isView(customDataBuffer) ?
      new DataView(customDataBuffer.buffer, customDataBuffer.byteOffset, customDataBuffer.byteLength) :
      new DataView(customDataBuffer);
  if (view.byteLength < 4) {
    throw new Error('Resize requires the opset version to be present in the custom data buffer');
  }
  return view.getUint32(0, true);
};
/**
 * Kernel entry point for Resize: determines the opset, collects and validates the optional roi/scales/sizes
 * inputs, then runs the resize program on input 0.
 */
export const resize = (context: ComputeContext, attributes: ResizeAttributes): void => {
  // Filled in by validateInputs().
  const scaleValues: number[] = [];
  const sizeValues: number[] = [];
  const roiValues: number[] = [];

  const opset = getOpsetVersionFromCustomDataBuffer(context);
  validateInputs(context.inputs, attributes, opset, scaleValues, sizeValues, roiValues);

  const programLoader =
      createResizeProgramInfoLoader(context.inputs[0], attributes, opset, scaleValues, sizeValues, roiValues);
  context.compute(programLoader, {inputs: [0]});
};
/**
 * Converts the raw attribute record of a Resize node into typed `ResizeAttributes` with a cache key.
 *
 * `excludeOutside` arrives as a number and is converted to a boolean; an empty `nearestMode` is mapped to
 * 'simple'. All other attributes are passed through unchanged.
 */
export const parseResizeAttributes = (attributes: Record<string, unknown>): ResizeAttributes => {
  const {
    antialias,
    axes,
    coordinateTransformMode,
    cubicCoeffA,
    excludeOutside,
    extrapolationValue,
    keepAspectRatioPolicy,
    mode,
    nearestMode
  } = attributes;
  // If nearestMode is not specified, use simple mode.
  const effectiveNearestMode = nearestMode === '' ? 'simple' : nearestMode;
  return createAttributeWithCacheKey({
    antialias: antialias as number,
    axes: axes as number[],
    coordinateTransformMode: coordinateTransformMode as CoordinateTransformMode,
    cubicCoeffA: cubicCoeffA as number,
    excludeOutside: (excludeOutside as number) !== 0,
    extrapolationValue: extrapolationValue as number,
    keepAspectRatioPolicy: keepAspectRatioPolicy as KeepAspectRatioPolicy,
    mode: mode as Mode,
    nearestMode: effectiveNearestMode as NearestMode
  });
};

View file

@ -196,6 +196,5 @@ export const parseSliceAttributes = (attributes: Record<string, unknown>): Slice
const starts = attributes.starts as number[];
const ends = attributes.ends as number[];
const axes = attributes.axes as number[];
const steps: number[] = [];
return createAttributeWithCacheKey({starts, ends, axes, steps});
return createAttributeWithCacheKey({starts, ends, axes});
};

View file

@ -10,10 +10,12 @@ import path from 'path';
const COMMENTS: Record<string, string> = {
'AveragePool': 'need perf optimization; need implementing activation',
'MaxPool': 'need perf optimization; need implementing activation',
'Conv': 'need perf optimization; conv3d not supported; need implementing activation',
'Conv': 'need perf optimization; conv3d is not supported; need implementing activation',
'ConvTranspose': 'need perf optimization; ConvTranspose3d is not supported; need implementing activation',
'Transpose': 'need perf optimization',
'Reshape': 'no GPU kernel',
'Shape': 'no GPU kernel; an ORT warning is generated - need to fix',
'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling',
};
/* eslint-disable max-len */

View file

@ -5,7 +5,12 @@
"ops": []
},
"webgl": {
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
"onnx": [
"resnet50",
"squeezenet",
"tiny_yolov2",
"emotion_ferplus"
],
"node": [
"test_abs",
"test_acos_example",
@ -976,34 +981,34 @@
"test_reshape_reordered_last_dims",
"test_reshape_zero_and_negative_dim",
"test_reshape_zero_dim",
// "test_resize_downsample_linear",
// "test_resize_downsample_nearest",
// "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
"test_resize_downsample_linear",
"test_resize_downsample_nearest",
"test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
// "test_resize_downsample_scales_cubic_align_corners",
// "test_resize_downsample_scales_cubic",
"test_resize_downsample_scales_cubic",
// "test_resize_downsample_scales_linear_align_corners",
// "test_resize_downsample_scales_linear",
// "test_resize_downsample_scales_nearest",
// "test_resize_downsample_sizes_cubic",
// "test_resize_downsample_sizes_linear_pytorch_half_pixel",
// "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
// "test_resize_downsample_sizes_nearest",
// "test_resize_nearest",
// "test_resize_tf_crop_and_resize",
// "test_resize_upsample_linear",
// "test_resize_upsample_nearest",
// "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
// "test_resize_upsample_scales_cubic_align_corners",
// "test_resize_upsample_scales_cubic_asymmetric",
// "test_resize_upsample_scales_cubic",
// "test_resize_upsample_scales_linear_align_corners",
// "test_resize_upsample_scales_linear",
// "test_resize_upsample_scales_nearest",
// "test_resize_upsample_sizes_cubic",
// "test_resize_upsample_sizes_nearest_ceil_half_pixel",
// "test_resize_upsample_sizes_nearest_floor_align_corners",
// "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
// "test_resize_upsample_sizes_nearest",
"test_resize_downsample_scales_linear",
"test_resize_downsample_scales_nearest",
"test_resize_downsample_sizes_cubic",
"test_resize_downsample_sizes_linear_pytorch_half_pixel",
"test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
"test_resize_downsample_sizes_nearest",
"test_resize_nearest",
"test_resize_tf_crop_and_resize",
"test_resize_upsample_linear",
"test_resize_upsample_nearest",
"test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
"test_resize_upsample_scales_cubic_align_corners",
"test_resize_upsample_scales_cubic_asymmetric",
"test_resize_upsample_scales_cubic",
"test_resize_upsample_scales_linear_align_corners",
"test_resize_upsample_scales_linear",
"test_resize_upsample_scales_nearest",
"test_resize_upsample_sizes_cubic",
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
"test_resize_upsample_sizes_nearest",
// // "test_reversesequence_batch",
// // "test_reversesequence_time",
// // "test_rnn_seq_length",
@ -1364,7 +1369,12 @@
]
},
"wasm": {
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
"onnx": [
"resnet50",
"squeezenet",
"tiny_yolov2",
"emotion_ferplus"
],
"node": [
// Check in node tests that have native Wasm implementations
// (i.e.) not tests that rely on the fallback cpu implementations
@ -1463,7 +1473,12 @@
"ops": []
},
"webnn": {
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
"onnx": [
"resnet50",
"squeezenet",
"tiny_yolov2",
"emotion_ferplus"
],
"node": [
// Check in node tests that have native Wasm implementations.
// (i.e.) not tests that rely on the fallback cpu implementations.

View file

@ -249,6 +249,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
@ -443,6 +448,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice)>,

View file

@ -102,7 +102,14 @@ class JsKernel : public OpKernel {
size_t index = 4;
for (int i = 0; i < context->InputCount(); i++) {
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->GetElementType());
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(context->Input<Tensor>(i)->DataRaw());
const auto* ptr = context->Input<Tensor>(i);
// Skip if the input is only a placeholder.
if (ptr == nullptr) {
p_serialized_kernel_context[index++] = 0;
p_serialized_kernel_context[index++] = 0;
continue;
}
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(ptr->DataRaw());
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape().NumDimensions());
for (size_t d = 0; d < context->Input<Tensor>(i)->Shape().NumDimensions(); d++) {
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape()[d]);

View file

@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "resize.h"
namespace onnxruntime {
namespace js {
// Kernel registrations for the ONNX Resize operator on the JS execution provider.
// Each macro below binds one opset range of Resize (ONNX domain) to the same
// js::Resize kernel class. All registrations restrict the tensor type
// constraints to float, and mark the auxiliary inputs with OrtMemTypeCPUInput.
// NOTE(review): OrtMemTypeCPUInput presumably makes those inputs available in
// host memory so they can be read without a device read-back -- confirm.
//
// Opset 10 only: a single auxiliary input (index 1) and one type constraint "T".
// NOTE(review): index 1 is presumably `scales` in the opset-10 schema -- confirm
// against the ONNX operator spec.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Resize,
kOnnxDomain,
10, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Resize);
// Opsets 11-12: three auxiliary inputs (roi, scales, sizes) and two type
// constraints, "T1" and "T2", both limited to float tensors.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Resize,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1) // roi
.InputMemoryType(OrtMemTypeCPUInput, 2) // scales
.InputMemoryType(OrtMemTypeCPUInput, 3) // sizes
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
Resize);
// Opsets 13-17: same builder configuration as the 11-12 registration
// (inputs 1, 2 and 3 on CPU; "T1" and "T2" limited to float).
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Resize,
kOnnxDomain,
13,
17,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2)
.InputMemoryType(OrtMemTypeCPUInput, 3)
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
Resize);
// Opset 18 only: same builder configuration as the 11-12 registration.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Resize,
kOnnxDomain,
18,
18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2)
.InputMemoryType(OrtMemTypeCPUInput, 3)
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
Resize);
// Opset 19 and later: registered with the non-versioned macro (start version
// only); same builder configuration as the 11-12 registration.
ONNX_OPERATOR_KERNEL_EX(
Resize,
kOnnxDomain,
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2)
.InputMemoryType(OrtMemTypeCPUInput, 3)
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<float>()),
Resize);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,138 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/tensor/upsamplebase.h"
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
// Resize kernel for the JS execution provider.
//
// Attribute parsing is inherited from UpsampleBase. The constructor converts
// the parsed attributes into plain values (integers, doubles and C strings)
// and forwards them to the JavaScript side through JSEP_INIT_KERNEL_ATTRIBUTE.
// The opset version the node was created with is passed separately, per
// invocation, through SerializeCustomData().
class Resize : public JsKernel, public UpsampleBase {
 public:
  // Builds the kernel and publishes its attributes to JavaScript.
  //
  // All strings and the `axes` buffer handed to the macro are locals of this
  // constructor, so they remain alive for the duration of the macro call.
  explicit Resize(const OpKernelInfo& info) : JsKernel(info), UpsampleBase(info) {
    const auto& node = info.node();
    opset_ = node.SinceVersion();

    auto resize_coordinate_transformation_mode = ResizeCoordinateTransformationModeToString(coordinate_transform_mode_);
    auto keep_aspect_ratio_policy = KeepAspectRatioPolicyToString(keep_aspect_ratio_policy_);
    auto nearest_mode = NearestModeToString(nearest_mode_);
    auto mode = UpsampleModeToString(mode_);

    // The JavaScript snippet reads the axes through HEAP32, so they are
    // narrowed into a contiguous int32_t buffer first.
    std::vector<int32_t> axes;
    axes.reserve(axes_.size());
    std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes),
                   [](const auto& axis) { return gsl::narrow_cast<int32_t>(axis); });

    JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({
                                 "antialias" : $1,
                                 "axes" : $2 ? Array.from(HEAP32.subarray($3, $3 + $2)) : [],
                                 "coordinateTransformMode" : UTF8ToString($4),
                                 "cubicCoeffA" : $5,
                                 "excludeOutside" : $6,
                                 "extrapolationValue" : $7,
                                 "keepAspectRatioPolicy" : UTF8ToString($8),
                                 "mode" : UTF8ToString($9),
                                 "nearestMode" : UTF8ToString($10),
                               }),
                               static_cast<int32_t>(antialias_),
                               gsl::narrow_cast<int32_t>(axes.size()),
                               // Byte address shifted right by 2 to obtain an index into HEAP32.
                               reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2,
                               resize_coordinate_transformation_mode.c_str(),
                               static_cast<double>(cubic_coeff_a_),
                               static_cast<int32_t>(exclude_outside_),
                               static_cast<double>(extrapolation_value_),
                               keep_aspect_ratio_policy.c_str(),
                               mode.c_str(),
                               nearest_mode.c_str());
  }

  // Returns the string name of an UpsampleMode value.
  // Throws (ORT_THROW) for a value that is not NN, LINEAR or CUBIC.
  std::string UpsampleModeToString(UpsampleMode mode) const {
    switch (mode) {
      case UpsampleMode::NN:
        return UpsampleModeNN;
      case UpsampleMode::LINEAR:
        return UpsampleModeLinear;
      case UpsampleMode::CUBIC:
        return UpsampleModeCubic;
      default:
        ORT_THROW("UpsampleMode is not supported!");
    }
  }

  // Returns the string name of an AspectRatioPolicy value.
  // Throws (ORT_THROW) for a value that is not handled below.
  std::string KeepAspectRatioPolicyToString(AspectRatioPolicy policy) const {
    switch (policy) {
      case AspectRatioPolicy::STRETCH:
        return "stretch";
      case AspectRatioPolicy::NOT_LARGER:
        return "not_larger";
      case AspectRatioPolicy::NOT_SMALLER:
        return "not_smaller";
      default:
        ORT_THROW("AspectRatioPolicy is not supported!");
    }
  }

  // Returns the string name of a ResizeCoordinateTransformationMode value.
  // Throws (ORT_THROW) for a value that is not handled below.
  std::string ResizeCoordinateTransformationModeToString(const ResizeCoordinateTransformationMode mode) const {
    switch (mode) {
      case ASYMMETRIC:
        return "asymmetric";
      case PYTORCH_HALF_PIXEL:
        return "pytorch_half_pixel";
      case TF_HALF_PIXEL_FOR_NN:
        return "tf_half_pixel_for_nn";
      case ALIGN_CORNERS:
        return "align_corners";
      case TF_CROP_AND_RESIZE:
        return "tf_crop_and_resize";
      case HALF_PIXEL:
        return "half_pixel";
      case HALF_PIXEL_SYMMETRIC:
        return "half_pixel_symmetric";
      default:
        ORT_THROW("ResizeCoordinateTransformationMode is not supported!");
    }
  }

  // Returns the string name of a ResizeNearestMode value.
  // Unlike the other converters this one does not throw: any value not listed
  // below is reported as an empty string.
  std::string NearestModeToString(const ResizeNearestMode mode) const {
    switch (mode) {
      case ROUND_PREFER_FLOOR:
        return "round_prefer_floor";
      case ROUND_PREFER_CEIL:
        return "round_prefer_ceil";
      case FLOOR:
        return "floor";
      case CEIL:
        return "ceil";
      default:
        return "";
    }
  }

  // Serializes the per-invocation custom data for this kernel: a single
  // int32_t holding the opset version captured at construction time.
  //
  // On success `*ptr` points at a buffer obtained from `alloc` and `*size` is
  // its size in bytes. Returns a FAIL status when the allocation yields null.
  // `context` is not used by this implementation.
  virtual Status SerializeCustomData(OpKernelContext* context, AllocatorPtr alloc, void** ptr, size_t* size) const {
    // The payload is exactly one 32-bit integer.
    const size_t customDataSize = sizeof(int32_t);

    void* p_custom_data = alloc->Alloc(customDataSize);
    if (p_custom_data == nullptr) {
      return Status(common::ONNXRUNTIME, common::FAIL, "failed to allocate memory for the custom data");
    }

    *reinterpret_cast<int32_t*>(p_custom_data) = static_cast<int32_t>(opset_);

    *ptr = p_custom_data;
    *size = customDataSize;
    return Status::OK();
  }

  // Opset version reported by the node (node.SinceVersion()) at construction.
  int opset_;
};
} // namespace js
} // namespace onnxruntime