mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
[JS/WebGPU] Add Resize operator (#16680)
### Description
Implemented Resize operator support in JSEP.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
3fd1d3b9bd
commit
77b2b618b2
10 changed files with 879 additions and 35 deletions
|
|
@ -22,8 +22,8 @@ Do not modify directly.*
|
|||
| Ceil | ai.onnx(6-12,13+) | |
|
||||
| Clip | ai.onnx(6-10,11,12,13+) | |
|
||||
| Concat | ai.onnx(1-3,4-10,11-12,13+) | |
|
||||
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d not supported; need implementing activation |
|
||||
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | |
|
||||
| Conv | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; conv3d is not supported; need implementing activation |
|
||||
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
|
||||
| Cos | ai.onnx(7+) | |
|
||||
| Cosh | ai.onnx(9+) | |
|
||||
| Div | ai.onnx(7-12,13,14+) | |
|
||||
|
|
@ -58,6 +58,7 @@ Do not modify directly.*
|
|||
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
|
||||
| Relu | ai.onnx(6-12,13,14+) | |
|
||||
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
|
||||
| Resize | ai.onnx(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
|
||||
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
|
||||
| Sigmoid | ai.onnx(6-12,13+) | |
|
||||
| Sin | ai.onnx(7+) | |
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
|
|||
import {matMul} from './ops/matmul';
|
||||
import * as pool from './ops/pool';
|
||||
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
|
||||
import {parseResizeAttributes, resize} from './ops/resize';
|
||||
import {parseSliceAttributes, slice} from './ops/slice';
|
||||
import {parseSplitAttributes, split} from './ops/split';
|
||||
import {parseTransposeAttributes, transpose} from './ops/transpose';
|
||||
|
|
@ -69,6 +70,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]],
|
||||
['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]],
|
||||
['Relu', [unaryOps.relu]],
|
||||
['Resize', [resize, parseResizeAttributes]],
|
||||
['Sigmoid', [unaryOps.sigmoid]],
|
||||
['Sin', [unaryOps.sin]],
|
||||
['Sinh', [unaryOps.sinh]],
|
||||
|
|
|
|||
595
js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Normal file
595
js/web/lib/wasm/jsep/webgpu/ops/resize.ts
Normal file
|
|
@ -0,0 +1,595 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {createIndicesHelper, ShaderHelper} from './common';
|
||||
|
||||
/** Coordinate transform modes for which this file can emit a WGSL mapping function. */
type CoordinateTransformMode =
    |'half_pixel'
    |'asymmetric'
    |'pytorch_half_pixel'
    |'tf_half_pixel_for_nn'
    |'align_corners'
    |'tf_crop_and_resize'
    |'half_pixel_symmetric';

/** Output-shape policy; anything other than 'stretch' makes the output shape get re-derived from a single scale. */
type KeepAspectRatioPolicy =
    |'stretch'
    |'not_smaller'
    |'not_larger';

/** Interpolation kind selecting which shader helpers are generated. */
type Mode =
    |'nearest'
    |'linear'
    |'cubic';

/** Rounding rule used by the 'nearest' mode ('simple' is what an empty attribute value is mapped to). */
type NearestMode =
    |'round_prefer_floor'
    |'round_prefer_ceil'
    |'floor'
    |'ceil'
    |'simple';

/** Parsed attributes of a Resize node, extended with a cache key. */
export interface ResizeAttributes extends AttributeWithCacheKey {
  // Parsed from the node attributes; not read by the shader generation in this file.
  antialias: number;
  // Axes that scales / sizes / roi refer to; an empty array means they cover every dimension.
  axes: number[];
  coordinateTransformMode: CoordinateTransformMode;
  // Coefficient used to build the four cubic interpolation weights.
  cubicCoeffA: number;
  // Cubic mode only: when true, weights of samples that fall outside the input are set to zero.
  excludeOutside: boolean;
  // Value emitted for output elements whose source coordinate is considered out of range.
  extrapolationValue: number;
  keepAspectRatioPolicy: KeepAspectRatioPolicy;
  mode: Mode;
  nearestMode: NearestMode;
}
|
||||
|
||||
/**
 * Validates the values read from the `scales` input.
 *
 * @param scales scale factors (may be empty, in which case nothing is checked).
 * @param attributes parsed Resize attributes; only `mode` is consulted.
 * @throws Error when a scale is not strictly positive, or when `mode` is 'linear' or 'cubic' and the scales are
 *     neither of length 2 nor of length 4 with a scale of 1 on dimension 0 and on dimension 1 or 3.
 */
const validateScales = (scales: number[], attributes: ResizeAttributes): void => {
  // A plain loop is used on purpose: wrapping the throw in an arrow function that is never called (as in
  // `cond || (() => { throw ... })`) silently accepts invalid values.
  for (const value of scales) {
    if (!(value > 0)) {
      throw new Error('Resize requires scales input values to be positive');
    }
  }

  // The dimensionality restriction only applies to the interpolating modes.
  if (scales.length === 0 || (attributes.mode !== 'linear' && attributes.mode !== 'cubic')) {
    return;
  }

  const isPlane = scales.length === 2;
  // Rank 4 is accepted when batch is unscaled and either dimension 1 or dimension 3 is unscaled.
  const isImageBatch = scales.length === 4 && scales[0] === 1 && (scales[1] === 1 || scales[3] === 1);
  if (!isPlane && !isImageBatch) {
    throw new Error(`Resize requires scales input size to be 2 or 4 for ${attributes.mode} mode`);
  }
};
|
||||
|
||||
/**
 * Expands per-axis scales to one scale per input dimension.
 *
 * @param scales scale factors, `scales[i]` belonging to dimension `axes[i]`.
 * @param axes dimensions the scales refer to; each must be in `[0, rank)`.
 * @param rank number of dimensions of the input.
 * @returns a new array of length `rank` holding 1.0 everywhere except at the listed axes.
 * @throws Error when an axis is negative or not smaller than `rank`.
 */
const updateScales = (scales: readonly number[], axes: readonly number[], rank: number): number[] => {
  // Checked with a real loop so that the error is actually raised for an invalid axis.
  for (const axis of axes) {
    if (!(axis >= 0 && axis < rank)) {
      throw new Error('Resize requires axes input values to be positive and less than rank');
    }
  }
  const newScales = new Array<number>(rank).fill(1.0);
  axes.forEach((axis, index) => {
    newScales[axis] = scales[index];
  });
  return newScales;
};
|
||||
|
||||
/**
 * Reads the optional roi / scales / sizes inputs into the supplied (initially empty) arrays and validates them.
 *
 * Input layout: for opset > 10 roi, scales and sizes are expected at indices 1, 2 and 3; otherwise only an optional
 * scales input at index 1 is read.
 *
 * @param inputs operator inputs; `inputs[0]` is the tensor being resized.
 * @param attributes parsed Resize attributes (`axes`, `coordinateTransformMode` and `mode` are consulted).
 * @param opsetVersion opset version of the node, used to pick the input layout.
 * @param scales out-parameter; on return holds one scale per input dimension, or stays empty.
 * @param sizes out-parameter; on return holds the values of the sizes input, or stays empty.
 * @param roi out-parameter; on return holds the values of the roi input, or stays empty.
 * @throws Error when roi is missing for 'tf_crop_and_resize', when scales / sizes do not have one entry per axis
 *     (if axes are given) or per input dimension (otherwise), when scale values are invalid, or when both scales
 *     and sizes are provided.
 */
const validateInputs =
    (inputs: readonly TensorView[], attributes: ResizeAttributes, opsetVersion: number, scales: number[],
     sizes: number[], roi: number[]): void => {
      const [roiInputIndex, scalesInputIndex, sizesInputIndex] =
          (opsetVersion > 10) ? [1, 2, 3] : [-1, (inputs.length > 1) ? 1 : -1, -1];
      const rank = inputs[0].dims.length;
      const axesCount = attributes.axes.length;

      if (roiInputIndex > 0 && inputs.length > roiInputIndex && inputs[roiInputIndex].dims.length > 0) {
        inputs[roiInputIndex].getFloat32Array().forEach((value) => roi.push(value));
      } else if (attributes.coordinateTransformMode === 'tf_crop_and_resize') {
        throw new Error('Resize requires RoI input to be specified when coordinateTransformMode is tfCropAndResize');
      }

      if (scalesInputIndex > 0 && inputs.length > scalesInputIndex && inputs[scalesInputIndex].dims.length > 0) {
        inputs[scalesInputIndex].getFloat32Array().forEach((value) => scales.push(value));
        // An empty scales tensor means "not provided"; length is only checked for real values.
        if (scales.length !== 0) {
          if (axesCount > 0) {
            if (scales.length !== axesCount) {
              throw new Error(
                  'Resize requires "scales" input size to be of axes rank when axes attributes is specified');
            }
          } else if (scales.length !== rank) {
            throw new Error(
                'Resize requires scales input size to be same as input rank or axes size for opset 18 and up');
          }
        }
        validateScales(scales, attributes);
        // Expand axes-sized scales to one scale per dimension. Skipped for empty scales, which would otherwise be
        // turned into an array containing undefined entries.
        if (axesCount > 0 && scales.length > 0) {
          updateScales(scales, attributes.axes, rank).forEach((value, index) => scales[index] = value);
        }
      }

      if (sizesInputIndex > 0 && inputs.length > sizesInputIndex) {
        inputs[sizesInputIndex].getBigInt64Array().forEach((value) => sizes.push(Number(value)));
        if (axesCount > 0) {
          if (sizes.length !== axesCount) {
            throw new Error(
                'Resize requires "sizes" input size to be of rank axes rank when axes attributes is specified');
          }
        } else if (sizes.length !== rank) {
          throw new Error('Resize requires sizes input size to be same as input rank or axes size for opset 18 and up');
        }
      }

      // scales and sizes are mutually exclusive ways of describing the output shape.
      if (scales.length > 0 && sizes.length > 0) {
        throw new Error('Resize requires only of scales or sizes to be specified');
      }
    };
|
||||
|
||||
/**
 * Generates the WGSL function `getOriginalCoordinateFromResizedCoordinate`, which maps a coordinate of the resized
 * (output) tensor back to a coordinate of the original (input) tensor along one dimension.
 *
 * @param coordinateTransferMode coordinate transform mode selecting the mapping formula.
 * @returns WGSL source of the function.
 * @throws Error when the mode is not one of the supported values.
 */
const getOriginalCoordinateFromResizedCoordinate = (coordinateTransferMode: CoordinateTransformMode): string => {
  const body = ((): string => {
    switch (coordinateTransferMode) {
      case 'asymmetric':
        return 'return xResized / xScale;';
      case 'pytorch_half_pixel':
        return `if (lengthResized > 1) {
    return (xResized + 0.5) / xScale - 0.5;
  } else {
    return 0.0;
  }`;
      case 'tf_half_pixel_for_nn':
        return 'return (xResized + 0.5) / xScale;';
      case 'align_corners':
        return `if (lengthResized == 1) {
    return 0.0;
  } else {
    return xResized * (lengthOriginal - 1) / (lengthResized - 1);
  }`;
      case 'tf_crop_and_resize':
        return `if (lengthResized > 1) {
    return roiStart * (lengthOriginal - 1) +
        (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1);
  } else {
    return 0.5 * (roiStart + roiEnd) * f32(lengthOriginal - 1);
  }`;
      case 'half_pixel_symmetric':
        // These values depend on function parameters, so they must be 'let' declarations: a WGSL 'const' requires
        // a const-expression initializer. outputWidth is the un-rounded output length, i.e. scale * input length.
        return [
          'let outputWidth = xScale * lengthOriginal;', 'let adjustment = lengthResized / outputWidth;',
          'let center = lengthOriginal / 2;', 'let offset = center * (1 - adjustment);',
          'return offset + ((xResized + 0.5) / xScale) - 0.5;'
        ].join('\n');
      case 'half_pixel':
        return 'return ((xResized + 0.5) / xScale) - 0.5;';
      default:
        throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
    }
  })();

  return `fn getOriginalCoordinateFromResizedCoordinate(xResized: f32, xScale: f32, lengthResized: f32,
    lengthOriginal: f32, roiStart: f32, roiEnd: f32) -> f32 {
  ${body}
}`;
};
|
||||
|
||||
/**
 * Generates the WGSL function `getNearestPixelFromOriginal`, which rounds a coordinate of the original tensor
 * according to the requested nearest mode.
 *
 * @param nearestMode rounding rule to generate.
 * @param opsetVersion opset version of the node; the 'simple' fallback is only available below opset 11.
 * @returns WGSL source of the function.
 * @throws Error when the mode is not one of the four explicit rounding rules and the opset is 11 or higher.
 */
const getNearestPixelFromOriginal = (nearestMode: NearestMode, opsetVersion: number): string => {
  // Ties (fractional part of exactly 0.5) go to the given function, everything else to round().
  const roundWithTieBreak = (tieBreak: 'floor'|'ceil'): string =>
      `if (fract(xOriginal) == 0.5) { return ${tieBreak}(xOriginal); } else { return round(xOriginal); }`;

  let body: string;
  if (nearestMode === 'round_prefer_ceil') {
    body = roundWithTieBreak('ceil');
  } else if (nearestMode === 'floor') {
    body = 'return floor(xOriginal);';
  } else if (nearestMode === 'ceil') {
    body = 'return ceil(xOriginal);';
  } else if (nearestMode === 'round_prefer_floor') {
    body = roundWithTieBreak('floor');
  } else if (opsetVersion < 11) {
    // 'simple' (and any unrecognised value) before opset 11.
    body = 'if (isDownSample) { return ceil(xOriginal); } else { return xOriginal; }';
  } else {
    throw new Error(`Nearest mode ${nearestMode} is not supported`);
  }

  return 'fn getNearestPixelFromOriginal(xOriginal: f32, isDownSample: bool) -> f32 {' + body + '}';
};
|
||||
|
||||
/**
 * Builds the region of interest used by the shader.
 *
 * The result for a full-rank RoI is laid out as `[start_0 .. start_{rank-1}, end_0 .. end_{rank-1}]`.
 *
 * @param roi values of the roi input (empty when not provided). When `axes` is non-empty it is read as
 *     `[start for axes[0] .. start for axes[n-1], end for axes[0] .. end for axes[n-1]]`.
 * @param axes dimensions the roi values refer to; empty means `roi` is returned as given.
 * @param rank number of dimensions of the input.
 * @returns a new array: the default RoI (all starts 0, all ends 1) when `roi` is empty, a copy of `roi` when no
 *     axes are given, otherwise the default RoI with the entries of the listed axes replaced.
 */
const updateRoI = (roi: readonly number[], axes: readonly number[], rank: number): number[] => {
  const fullRoi = new Array<number>(rank).fill(0).concat(new Array<number>(rank).fill(1));
  if (roi.length === 0) {
    // Nothing to scatter; returning early also avoids reading from the array that is being written.
    return fullRoi;
  }
  if (axes.length === 0) {
    return roi.slice();
  }
  axes.forEach((axis, i) => {
    fullRoi[axis] = roi[i];
    // The end value belongs to the same axis as the start value, offset by rank.
    fullRoi[axis + rank] = roi[axes.length + i];
  });
  return fullRoi;
};
|
||||
|
||||
/**
 * Computes the initial output shape from either `sizes` or `scales`.
 *
 * @param inputShape shape of the tensor being resized.
 * @param scales one scale per input dimension; only used when `sizes` is empty.
 * @param sizes requested output sizes; per axis when `axes` is non-empty, otherwise the complete output shape.
 * @param axes dimensions the sizes refer to (may be empty).
 * @returns a new array holding the output shape.
 * @throws Error when an axis is outside `[0, rank)` or when neither sizes nor scales are provided.
 */
const initOutputShape =
    (inputShape: readonly number[], scales: readonly number[], sizes: readonly number[], axes: readonly number[]):
        number[] => {
          if (sizes.length > 0) {
            if (axes.length === 0) {
              return sizes.slice();
            }
            // Valid axes are 0 .. rank-1; an axis equal to rank would silently append a dimension.
            if (axes.some((axis) => axis < 0 || axis >= inputShape.length)) {
              throw new Error('axes is out of bound');
            }
            const outputShape = inputShape.slice();
            axes.forEach((axis, i) => {
              outputShape[axis] = sizes[i];
            });
            return outputShape;
          }

          if (scales.length === 0) {
            throw new Error('Resize requires either scales or sizes.');
          }
          // NOTE(review): rounds to nearest; confirm against the rounding rule the operator spec prescribes.
          return inputShape.map((value, index) => Math.round(value * scales[index]));
        };
|
||||
|
||||
/**
 * Re-derives the output shape so that a single scale is applied, as requested by the keep-aspect-ratio policy.
 *
 * Side effect: `scales` is overwritten in place. It ends up holding the selected scale on the resized dimensions
 * (all dimensions when no axes are given) and 1.0 elsewhere.
 *
 * @param inputShape shape of the tensor being resized.
 * @param outputShape unused.
 * @param scales one scale per input dimension; read to select the policy scale, then overwritten.
 * @param attributes parsed Resize attributes (`keepAspectRatioPolicy` and `axes` are consulted).
 * @returns a new array holding the adjusted output shape.
 * @throws Error when the policy is neither 'not_larger' nor 'not_smaller'.
 */
const adjustOutputShape =
    (inputShape: readonly number[], outputShape: readonly number[], scales: number[], attributes: ResizeAttributes):
        number[] => {
          // Scales taking part in the selection: those of the listed axes, or all of them.
          const candidateScales = (): number[] =>
              attributes.axes.length > 0 ? attributes.axes.map((axis) => scales[axis]) : scales;

          let scaleInPolicy: number;
          if (attributes.keepAspectRatioPolicy === 'not_larger') {
            scaleInPolicy = Math.min(...candidateScales(), Number.MAX_VALUE);
          } else if (attributes.keepAspectRatioPolicy === 'not_smaller') {
            scaleInPolicy = Math.max(...candidateScales(), Number.MIN_VALUE);
          } else {
            throw new Error(`Keep aspect ratio policy ${attributes.keepAspectRatioPolicy} is not supported`);
          }

          scales.fill(1.0, 0, scales.length);
          const adjustedOutputShape = inputShape.slice();
          if (attributes.axes.length > 0) {
            for (const axis of attributes.axes) {
              scales[axis] = scaleInPolicy;
            }
            for (const axis of attributes.axes) {
              adjustedOutputShape[axis] = Math.round(inputShape[axis] * scales[axis]);
            }
          } else {
            scales.fill(scaleInPolicy, 0, scales.length);
            for (let i = 0; i < adjustedOutputShape.length; i++) {
              adjustedOutputShape[i] = Math.round(adjustedOutputShape[i] * scales[i]);
            }
          }
          return adjustedOutputShape;
        };
|
||||
|
||||
/**
 * Generates the WGSL function `calculateOriginalIndicesFromOutputIndices`, which maps output indices to
 * (fractional) coordinates in the input tensor, one per dimension.
 *
 * Shapes, scales and roi are baked into the shader as constant arrays. A dimension whose scale is exactly 1.0
 * keeps its output index; every other dimension goes through `getOriginalCoordinateFromResizedCoordinate`, which
 * must be part of the same shader.
 *
 * @param inputShape shape of the tensor being resized.
 * @param outputShape shape of the output tensor.
 * @param scales one scale per dimension.
 * @param roi region of interest; element `i` is the start and element `i + inputShape.length` the end of
 *     dimension `i`.
 * @returns WGSL source of the function.
 */
const calculateOriginalIndicesFromOutputIndices =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[]):
        string => {
          const outputIndicesType = createIndicesHelper('output', outputShape).iType;
          const inputRank = inputShape.length;
          const outputRank = outputShape.length;
          // Literal lists for the constant arrays embedded in the shader.
          const asU32List = (values: readonly number[]): string => values.map((v) => `${v}u`).join(',');
          const asF32List = (values: readonly number[]): string => values.map((v) => `${v}f`).join(',');
          // For rank 1 the indices are used as a plain value rather than indexed.
          const outputIndexExpression = outputRank === 1 ? 'outputIndices' : 'outputIndices[i]';

          return `
    fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${outputIndicesType}) -> array<f32, ${outputRank}> {
      const inputShape = array<u32, ${inputRank}>(${asU32List(inputShape)});
      const outputShape = array<u32, ${outputRank}>(${asU32List(outputShape)});
      const scales = array<f32, ${scales.length}>(${asF32List(scales)});
      const roi = array<f32, ${roi.length}>(${asF32List(roi)});
      var originalIndices: array<f32, ${outputRank}>;
      for (var i:u32 = 0; i < ${outputRank}; i++) {
        var outputIndex = ${outputIndexExpression};
        if (scales[i] == 1.0) {
          originalIndices[i] = f32(outputIndex);
        } else {
          originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
                f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputRank}]);
        }
      }
      return originalIndices;
    }`;
        };
|
||||
|
||||
/**
 * Generates the WGSL function `calculateInputIndicesFromOutputIndices`, used by the 'nearest' mode to pick, for
 * given output indices, the indices of the input element to copy.
 *
 * Shapes, scales and roi are baked into the shader as constant arrays. The generated code relies on
 * `getOriginalCoordinateFromResizedCoordinate` and `getNearestPixelFromOriginal` being part of the same shader.
 *
 * @param inputShape shape of the tensor being resized.
 * @param outputShape shape of the output tensor.
 * @param scales one scale per dimension.
 * @param roi region of interest; element `i` is the start and element `i + inputShape.length` the end of
 *     dimension `i`.
 * @param useExtrapolation when true, coordinates outside `[0, inputShape[i])` are not clamped or rounded but
 *     converted to u32 as they are.
 * @returns WGSL source of the function.
 */
const calculateInputIndicesFromOutputIndices =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[],
     useExtrapolation: boolean): string => {
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      // The function returns the variable declared with the input indices type, so the declared return type uses
      // that same type. A hard-coded array<u32, N> disagrees with it when the indices type is a plain value,
      // which is how this function itself treats rank 1 below.
      return `
    fn calculateInputIndicesFromOutputIndices(outputIndices: ${outputIndicesHelper.iType}) -> ${
          inputIndicesHelper.iType} {
      const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
      const outputShape = array<u32, ${outputShape.length}>(${outputShape.map(i => `${i}u`).join(',')});
      const scales = array<f32, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
      const roi = array<f32, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
      var inputIndices: ${inputIndicesHelper.iType};
      for (var i:u32 = 0; i < ${outputShape.length}; i++) {
        var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
        var inputIndex: u32;
        if (scales[i] == 1.0) {
          inputIndex = outputIndex;
        } else {
          var original_idx = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), scales[i],
                f32(outputShape[i]), f32(inputShape[i]), roi[i], roi[i + ${inputShape.length}]);
          if (!${useExtrapolation} || (original_idx >= 0 && original_idx < f32(inputShape[i]))) {
            if (original_idx < 0) {
              inputIndex = 0;
            } else if (original_idx > (f32(inputShape[i]) - 1)) {
              inputIndex = inputShape[i] - 1;
            } else {
              inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
            }
          } else {
            inputIndex = u32(original_idx);
          }
        }
        ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
      }
      return inputIndices;
    }`;
    };
|
||||
|
||||
/**
 * Generates the WGSL function `checkInputIndices`, which reports whether every index lies inside the input shape.
 *
 * @param inputShape shape of the tensor being resized; baked into the shader as a constant array.
 * @returns WGSL source of the function.
 */
const checkInputIndices = (inputShape: readonly number[]): string => {
  // Only the indices type of the helper is needed here.
  const indicesType = createIndicesHelper('output', inputShape).iType;
  const rank = inputShape.length;
  const shapeLiteral = inputShape.map((dim) => `${dim}u`).join(',');
  // For rank 1 the indices are used as a plain value rather than indexed.
  const indexExpression = rank === 1 ? 'inputIndices' : 'inputIndices[i]';
  return `
    fn checkInputIndices(inputIndices: ${indicesType}) -> bool {
      const inputShape = array<u32, ${rank}>(${shapeLiteral});
      for (var i:u32 = 0; i < ${rank}; i++) {
        var inputIndex = ${indexExpression};
        if (inputIndex < 0 || inputIndex >= inputShape[i]) {
          return false;
        }
      }
      return true;
    }`;
};
|
||||
|
||||
/**
 * Generates the WGSL functions `getInputValue` and `bilinearInterpolation` used by the 'linear' mode.
 *
 * The generated code relies on `calculateOriginalIndicesFromOutputIndices`, the `input` storage buffer and the
 * input indices-to-offset helper being part of the same shader.
 *
 * Dimension roles: for a rank-2 input the dimensions are (height, width). Otherwise the layout is
 * (batch, channel, height, width) when `scales[1]` is exactly 1.0 and (batch, height, width, channel) if not.
 *
 * @param inputShape shape of the tensor being resized.
 * @param outputShape shape of the output tensor.
 * @param scales one scale per dimension; only `scales[1]` is consulted, to detect the layout.
 * @param useExtrapolation when true, output elements mapping outside the input yield `extrapolationValue`.
 * @param extrapolationValue value returned for such elements.
 * @returns WGSL source of both functions.
 */
const bilinearInterpolation =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
     useExtrapolation: boolean, extrapolationValue: number): string => {
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const [batchIdx, heightIdx, widthIdx, channelIdx] =
          inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);

      // Batch / channel handling is only emitted when those dimensions exist. Emitting it inside a statically
      // false 'if' is not enough for rank-2 inputs: the shader compiler still validates the dead code and rejects
      // the constant index -1.
      const hasBatchAndChannel = inputShape.length > 2;
      const storeBatchAndChannel = hasBatchAndChannel ? `
      inputIndices[${channelIdx}] = channel;
      inputIndices[${batchIdx}] = batch;` :
                                                        '';
      const loadBatchAndChannel = hasBatchAndChannel ? `
      channel = u32(originalIndices[${channelIdx}]);
      batch = u32(originalIndices[${batchIdx}]);` :
                                                       '';

      return `
    fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> f32 {
      var inputIndices: ${inputIndicesHelper.iType};
      inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
      inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));${storeBatchAndChannel}
      return input[${inputIndicesHelper.i2oExpression('inputIndices')}];
    }

    fn bilinearInterpolation(outputIndices: ${outputIndicesHelper.iType}) -> f32 {
      var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
      var row:f32 = originalIndices[${heightIdx}];
      var col:f32 = originalIndices[${widthIdx}];
      if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
          inputShape[widthIdx]} - 1)) {
        return ${extrapolationValue};
      }
      row = max(0, min(row, ${inputShape[heightIdx]} - 1));
      col = max(0, min(col, ${inputShape[widthIdx]} - 1));
      var row1: u32 = u32(row);
      var col1: u32 = u32(col);
      var row2: u32 = u32(row + 1);
      var col2: u32 = u32(col + 1);
      var channel: u32 = 0;
      var batch: u32 = 0;${loadBatchAndChannel}
      var x11: f32 = getInputValue(batch, channel, row1, col1);
      var x12: f32 = getInputValue(batch, channel, row1, col2);
      var x21: f32 = getInputValue(batch, channel, row2, col1);
      var x22: f32 = getInputValue(batch, channel, row2, col2);
      var dx1: f32 = row - f32(row1);
      var dx2: f32 = f32(row2 ) - row;
      var dy1 = col - f32(col1);
      var dy2 = f32(col2) - col;
      return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
    }`;
    };
|
||||
|
||||
/**
 * Generates the WGSL functions used by the 'cubic' mode: `rowCubicInterpolation`, `colCubicInterpolation`,
 * `getCubicInterpolationCoefs`, `cubicInterpolation1D` and the entry point `bicubicInterpolation`.
 *
 * The column pass gathers four samples, each of which is itself the result of a row pass. The generated code
 * relies on `getOriginalCoordinateFromResizedCoordinate`, the `input` storage buffer and the input
 * indices-to-offset helper being part of the same shader.
 *
 * Height / width are dimensions (0, 1) for a rank-2 input, otherwise (2, 3) when `scales[1]` is exactly 1.0 and
 * (1, 2) if not.
 *
 * @param inputShape shape of the tensor being resized.
 * @param outputShape shape of the output tensor.
 * @param scales one scale per dimension.
 * @param roi region of interest; element `i` is the start and element `i + inputShape.length` the end of
 *     dimension `i`.
 * @param cubicCoeffA coefficient used to build the four interpolation weights.
 * @param useExtrapolation when true, coordinates outside the input yield `extrapolationValue`.
 * @param extrapolationValue value returned for such coordinates.
 * @param excludeOutside when true, samples outside the input get a weight of zero instead.
 * @returns WGSL source of all five functions.
 */
const bicubicInterpolation =
    (inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[], roi: readonly number[],
     cubicCoeffA: number, useExtrapolation: boolean, extrapolationValue: number, excludeOutside: boolean): string => {
      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];

      // Emits the 1-D interpolation along dimension idx ('row' for the height dimension, 'col' otherwise).
      const createCubicInterpolationFunction = (idx: number): string => {
        const direction = idx === heightIdx ? 'row' : 'col';
        // roi end of a dimension is stored inputShape.length entries after its start.
        const roiStart = roi[idx];
        const roiEnd = roi[idx + inputShape.length];
        return `
      fn ${direction}CubicInterpolation(inputIndices: ${inputIndicesHelper.iType}, outputIndices: ${
            outputIndicesHelper.iType}) -> f32 {
        var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
        var originalIdx: f32 = getOriginalCoordinateFromResizedCoordinate(f32(outputIndex), ${scales[idx]},
              f32(${outputShape[idx]}), f32(${inputShape[idx]}), ${roiStart}, ${roiEnd});
        var fractOriginalIdx: f32 = originalIdx - floor(originalIdx);
        var coefs = getCubicInterpolationCoefs(fractOriginalIdx);

        if (${useExtrapolation} && (originalIdx < 0 || originalIdx > (${inputShape[idx]} - 1))) {
          return ${extrapolationValue};
        }
        var data: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
        for (var i: i32 = -1; i < 3; i++) {
          var ${direction}: f32 = originalIdx + f32(i);
          if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
            if (${excludeOutside}) {
              coefs[i + 1] = 0.0;
              continue;
            } else if (${useExtrapolation}) {
              return ${extrapolationValue};
            } else {
              ${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
            }
          }
          var inputIndicesCopy: ${inputIndicesHelper.iType} = inputIndices;
          inputIndicesCopy[${idx}] = u32(${direction});
          data[i + 1] = ${idx === heightIdx ? `input[${inputIndicesHelper.i2oExpression('inputIndicesCopy')}];` : `
                                               rowCubicInterpolation(inputIndicesCopy, outputIndices);`}
        }
        return cubicInterpolation1D(data, coefs);
      }`;
      };

      return `
    ${createCubicInterpolationFunction(heightIdx)};
    ${createCubicInterpolationFunction(widthIdx)};
    fn getCubicInterpolationCoefs(s: f32) -> array<f32, 4> {
      var absS = abs(s);
      var coeffs: array<f32, 4> = array<f32, 4>(0.0, 0.0, 0.0, 0.0);
      var oneMinusAbsS: f32 = 1.0 - absS;
      var twoMinusAbsS: f32 = 2.0 - absS;
      var onePlusAbsS: f32 = 1.0 + absS;
      coeffs[0] = ((${cubicCoeffA} * onePlusAbsS - 5 * ${cubicCoeffA}) * onePlusAbsS + 8 * ${
          cubicCoeffA}) * onePlusAbsS - 4 * ${cubicCoeffA};
      coeffs[1] = ((${cubicCoeffA} + 2) * absS - (${cubicCoeffA} + 3)) * absS * absS + 1;
      coeffs[2] = ((${cubicCoeffA} + 2) * oneMinusAbsS - (${cubicCoeffA} + 3)) * oneMinusAbsS * oneMinusAbsS + 1;
      coeffs[3] = ((${cubicCoeffA} * twoMinusAbsS - 5 * ${cubicCoeffA}) * twoMinusAbsS + 8 * ${
          cubicCoeffA}) * twoMinusAbsS - 4 * ${cubicCoeffA};
      return coeffs;
    }

    fn cubicInterpolation1D(x: array<f32, 4>, coefs: array<f32, 4>) -> f32 {
      var coefsSum: f32 = coefs[0] + coefs[1] + coefs[2] + coefs[3];
      return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
    }

    fn bicubicInterpolation(outputIndices: ${outputIndicesHelper.iType}) -> f32 {
      var inputIndices: ${inputIndicesHelper.iType} = outputIndices;
      return colCubicInterpolation(inputIndices, outputIndices);
    }
    `;
    };
|
||||
|
||||
/**
 * Builds the program info (shader source generator, output description and dispatch size) for one Resize call.
 *
 * @param metadata program metadata, spread into the result.
 * @param input tensor being resized.
 * @param attributes parsed Resize attributes.
 * @param opsetVersion opset version of the node; forwarded to the nearest-pixel shader generator.
 * @param scalesInput one scale per input dimension, or empty when the output is described by `sizes`.
 * @param sizes values of the sizes input, or empty.
 * @param roiInput values of the roi input, or empty.
 * @returns the program info; shader-generation errors (e.g. an unsupported mode) surface when its
 *     `getShaderSource` is invoked.
 */
const createResizeProgramInfo =
    (metadata: ProgramMetadata, input: TensorView, attributes: ResizeAttributes, opsetVersion: number,
     scalesInput: readonly number[], sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
      const inputShape = input.dims;
      const roi = updateRoI(roiInput, attributes.axes, inputShape.length);

      let outputShape = initOutputShape(inputShape, scalesInput, sizes, attributes.axes);
      let scales = scalesInput.slice();
      if (scalesInput.length === 0) {
        // No scales given: derive them from the requested output shape (an empty dimension keeps scale 1).
        scales = inputShape.map((dim, axis) => (dim === 0 ? 1.0 : outputShape[axis] / dim));
        if (attributes.keepAspectRatioPolicy !== 'stretch') {
          // Also rewrites `scales` in place.
          outputShape = adjustOutputShape(inputShape, outputShape, scales, attributes);
        }
      }

      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const outputSize = ShapeUtil.size(outputShape);
      const dataType = 'f32';
      // Identical shapes: the shader degenerates to an element-wise copy.
      const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
      const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';

      // WGSL helper functions required by the selected interpolation mode.
      const emitModeHelpers = (): string => {
        if (attributes.mode === 'nearest') {
          return `
              ${checkInputIndices(inputShape)};
              ${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion)};
              ${calculateInputIndicesFromOutputIndices(inputShape, outputShape, scales, roi, useExtrapolation)};
              `;
        }
        if (attributes.mode === 'linear') {
          return `
              ${calculateOriginalIndicesFromOutputIndices(inputShape, outputShape, scales, roi)};
              ${
              bilinearInterpolation(
                  inputShape, outputShape, scales, useExtrapolation, attributes.extrapolationValue)};
              `;
        }
        if (attributes.mode === 'cubic') {
          return `
            ${
              bicubicInterpolation(
                  inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
                  attributes.extrapolationValue, attributes.excludeOutside)};
            `;
        }
        throw Error('Invalid resize mode');
      };

      // WGSL statement(s) storing the value of the current output element.
      const emitOutputAssignment = (): string => {
        if (attributes.mode === 'nearest') {
          return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices);
                  if (checkInputIndices(inputIndices)) {
                    output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}];
                  } else {
                    output[global_idx] = ${attributes.extrapolationValue};
                  }`;
        }
        if (attributes.mode === 'linear') {
          return 'output[global_idx] = bilinearInterpolation(outputIndices);';
        }
        if (attributes.mode === 'cubic') {
          return 'output[global_idx] = bicubicInterpolation(outputIndices);';
        }
        throw Error(`Unsupported resize mode: ${attributes.mode}`);
      };

      const getShaderSource = (shaderHelper: ShaderHelper) => `
      ${getOriginalCoordinateFromResizedCoordinate(attributes.coordinateTransformMode)};
      ${emitModeHelpers()};
      @group(0) @binding(0) var<storage, read> input : array<${dataType}>;
      @group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
      ${outputIndicesHelper.o2iImpl}
      ${inputIndicesHelper.i2oImpl}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
        if (${noScale}) {
          output[global_idx] = input[global_idx];
        } else {
          ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
          ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
          ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
          ${emitOutputAssignment()};
        }
      }`;

      return {
        ...metadata,
        getShaderSource,
        outputs: [{dims: outputShape, dataType: input.dataType, gpuDataType: GpuDataType.default}],
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };
|
||||
|
||||
/**
 * Creates the lazy program-info loader for one Resize call.
 *
 * The cache hint covers everything that is baked into the generated shader source besides the attributes: the
 * opset version, the scales, the sizes and the region of interest. Leaving roi out would let two calls that differ
 * only in their roi share one cache hint even though their shaders embed different constants.
 *
 * @param input tensor being resized.
 * @param attributes parsed Resize attributes; `cacheKey` is part of the cache hint.
 * @param opsetVersion opset version of the node.
 * @param scales one scale per input dimension, or empty.
 * @param sizes values of the sizes input, or empty.
 * @param roi values of the roi input, or empty.
 * @returns the program metadata plus a `get` function that builds the program info on demand.
 */
export const createResizeProgramInfoLoader =
    (input: TensorView, attributes: ResizeAttributes, opsetVersion: number, scales: readonly number[],
     sizes: readonly number[], roi: readonly number[]): ProgramInfoLoader => {
      const metadata: ProgramMetadata = {
        name: 'Resize',
        inputTypes: [GpuDataType.default],
        cacheHint: attributes.cacheKey + opsetVersion.toString() +
            (scales.length > 0 ? '_scales_' + scales.toString() : '') +
            (sizes.length > 0 ? '_sizes_' + sizes.toString() : '') +
            (roi.length > 0 ? '_roi_' + roi.toString() : ''),
      };
      return {
        ...metadata,
        get: () => createResizeProgramInfo(metadata, input, attributes, opsetVersion, scales, sizes, roi)
      };
    };
|
||||
|
||||
/**
 * Reads the opset version stored as the first 32-bit unsigned integer of the kernel's custom data.
 *
 * @param context compute context of the kernel invocation.
 * @returns the opset version.
 */
const getOpsetVersionFromCustomDataBuffer = (context: ComputeContext): number => {
  // customDataBuffer is a view (it exposes byteOffset), so the value has to be read from the underlying
  // ArrayBuffer. Passing the view itself to a typed-array constructor copies its elements and ignores the offset
  // and length arguments, which yields only the first byte. A DataView is used because it has no alignment
  // requirement on byteOffset.
  // NOTE(review): assumes the value is stored little-endian - confirm against the code that writes the buffer.
  const {buffer, byteOffset, byteLength} = context.customDataBuffer;
  return new DataView(buffer, byteOffset, byteLength).getUint32(0, true);
};
|
||||
|
||||
/**
 * Entry point of the Resize kernel: reads and validates the optional inputs, then runs the resize program on
 * input 0.
 *
 * @param context compute context of the kernel invocation.
 * @param attributes parsed Resize attributes.
 */
export const resize = (context: ComputeContext, attributes: ResizeAttributes): void => {
  // Filled in by validateInputs.
  const scales: number[] = [];
  const sizes: number[] = [];
  const roi: number[] = [];

  const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
  validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);

  const programInfoLoader =
      createResizeProgramInfoLoader(context.inputs[0], attributes, opsetVersion, scales, sizes, roi);
  context.compute(programInfoLoader, {inputs: [0]});
};
|
||||
|
||||
/**
 * Converts the raw attribute record of a Resize node into typed `ResizeAttributes` with a cache key.
 *
 * @param attributes raw attributes, keyed by camel-cased attribute name.
 * @returns the typed attributes; `excludeOutside` is true for any non-zero value and an empty `nearestMode` is
 *     replaced by 'simple'.
 */
export const parseResizeAttributes = (attributes: Record<string, unknown>): ResizeAttributes => {
  const rawNearestMode = attributes.nearestMode;
  return createAttributeWithCacheKey({
    antialias: attributes.antialias as number,
    axes: attributes.axes as number[],
    coordinateTransformMode: attributes.coordinateTransformMode as CoordinateTransformMode,
    cubicCoeffA: attributes.cubicCoeffA as number,
    excludeOutside: attributes.excludeOutside as number !== 0,
    extrapolationValue: attributes.extrapolationValue as number,
    keepAspectRatioPolicy: attributes.keepAspectRatioPolicy as KeepAspectRatioPolicy,
    mode: attributes.mode as Mode,
    // If nearestMode is not specified, use simple mode.
    nearestMode: (rawNearestMode === '' ? 'simple' : rawNearestMode) as NearestMode
  });
};
|
||||
|
|
@ -196,6 +196,5 @@ export const parseSliceAttributes = (attributes: Record<string, unknown>): Slice
|
|||
const starts = attributes.starts as number[];
|
||||
const ends = attributes.ends as number[];
|
||||
const axes = attributes.axes as number[];
|
||||
const steps: number[] = [];
|
||||
return createAttributeWithCacheKey({starts, ends, axes, steps});
|
||||
return createAttributeWithCacheKey({starts, ends, axes});
|
||||
};
|
||||
|
|
|
|||
|
|
@ -10,10 +10,12 @@ import path from 'path';
|
|||
const COMMENTS: Record<string, string> = {
|
||||
'AveragePool': 'need perf optimization; need implementing activation',
|
||||
'MaxPool': 'need perf optimization; need implementing activation',
|
||||
'Conv': 'need perf optimization; conv3d not supported; need implementing activation',
|
||||
'Conv': 'need perf optimization; conv3d is not supported; need implementing activation',
|
||||
'ConvTranspose': 'need perf optimization; ConvTranspose3d is not supported; need implementing activation',
|
||||
'Transpose': 'need perf optimization',
|
||||
'Reshape': 'no GPU kernel',
|
||||
'Shape': 'no GPU kernel; an ORT warning is generated - need to fix',
|
||||
'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling',
|
||||
};
|
||||
|
||||
/* eslint-disable max-len */
|
||||
|
|
|
|||
|
|
@ -5,7 +5,12 @@
|
|||
"ops": []
|
||||
},
|
||||
"webgl": {
|
||||
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
|
||||
"onnx": [
|
||||
"resnet50",
|
||||
"squeezenet",
|
||||
"tiny_yolov2",
|
||||
"emotion_ferplus"
|
||||
],
|
||||
"node": [
|
||||
"test_abs",
|
||||
"test_acos_example",
|
||||
|
|
@ -976,34 +981,34 @@
|
|||
"test_reshape_reordered_last_dims",
|
||||
"test_reshape_zero_and_negative_dim",
|
||||
"test_reshape_zero_dim",
|
||||
// "test_resize_downsample_linear",
|
||||
// "test_resize_downsample_nearest",
|
||||
// "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
"test_resize_downsample_linear",
|
||||
"test_resize_downsample_nearest",
|
||||
"test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
// "test_resize_downsample_scales_cubic_align_corners",
|
||||
// "test_resize_downsample_scales_cubic",
|
||||
"test_resize_downsample_scales_cubic",
|
||||
// "test_resize_downsample_scales_linear_align_corners",
|
||||
// "test_resize_downsample_scales_linear",
|
||||
// "test_resize_downsample_scales_nearest",
|
||||
// "test_resize_downsample_sizes_cubic",
|
||||
// "test_resize_downsample_sizes_linear_pytorch_half_pixel",
|
||||
// "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
|
||||
// "test_resize_downsample_sizes_nearest",
|
||||
// "test_resize_nearest",
|
||||
// "test_resize_tf_crop_and_resize",
|
||||
// "test_resize_upsample_linear",
|
||||
// "test_resize_upsample_nearest",
|
||||
// "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
// "test_resize_upsample_scales_cubic_align_corners",
|
||||
// "test_resize_upsample_scales_cubic_asymmetric",
|
||||
// "test_resize_upsample_scales_cubic",
|
||||
// "test_resize_upsample_scales_linear_align_corners",
|
||||
// "test_resize_upsample_scales_linear",
|
||||
// "test_resize_upsample_scales_nearest",
|
||||
// "test_resize_upsample_sizes_cubic",
|
||||
// "test_resize_upsample_sizes_nearest_ceil_half_pixel",
|
||||
// "test_resize_upsample_sizes_nearest_floor_align_corners",
|
||||
// "test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
|
||||
// "test_resize_upsample_sizes_nearest",
|
||||
"test_resize_downsample_scales_linear",
|
||||
"test_resize_downsample_scales_nearest",
|
||||
"test_resize_downsample_sizes_cubic",
|
||||
"test_resize_downsample_sizes_linear_pytorch_half_pixel",
|
||||
"test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
|
||||
"test_resize_downsample_sizes_nearest",
|
||||
"test_resize_nearest",
|
||||
"test_resize_tf_crop_and_resize",
|
||||
"test_resize_upsample_linear",
|
||||
"test_resize_upsample_nearest",
|
||||
"test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
"test_resize_upsample_scales_cubic_align_corners",
|
||||
"test_resize_upsample_scales_cubic_asymmetric",
|
||||
"test_resize_upsample_scales_cubic",
|
||||
"test_resize_upsample_scales_linear_align_corners",
|
||||
"test_resize_upsample_scales_linear",
|
||||
"test_resize_upsample_scales_nearest",
|
||||
"test_resize_upsample_sizes_cubic",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
|
||||
"test_resize_upsample_sizes_nearest",
|
||||
// // "test_reversesequence_batch",
|
||||
// // "test_reversesequence_time",
|
||||
// // "test_rnn_seq_length",
|
||||
|
|
@ -1364,7 +1369,12 @@
|
|||
]
|
||||
},
|
||||
"wasm": {
|
||||
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
|
||||
"onnx": [
|
||||
"resnet50",
|
||||
"squeezenet",
|
||||
"tiny_yolov2",
|
||||
"emotion_ferplus"
|
||||
],
|
||||
"node": [
|
||||
// Check in node tests that have native Wasm implementations
|
||||
// (i.e.) not tests that rely on the fallback cpu implementations
|
||||
|
|
@ -1463,7 +1473,12 @@
|
|||
"ops": []
|
||||
},
|
||||
"webnn": {
|
||||
"onnx": ["resnet50", "squeezenet", "tiny_yolov2", "emotion_ferplus"],
|
||||
"onnx": [
|
||||
"resnet50",
|
||||
"squeezenet",
|
||||
"tiny_yolov2",
|
||||
"emotion_ferplus"
|
||||
],
|
||||
"node": [
|
||||
// Check in node tests that have native Wasm implementations.
|
||||
// (i.e.) not tests that rely on the fallback cpu implementations.
|
||||
|
|
|
|||
|
|
@ -249,6 +249,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
|
||||
|
|
@ -443,6 +448,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
|
||||
|
|
|
|||
|
|
@ -102,7 +102,14 @@ class JsKernel : public OpKernel {
|
|||
size_t index = 4;
|
||||
for (int i = 0; i < context->InputCount(); i++) {
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->GetElementType());
|
||||
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(context->Input<Tensor>(i)->DataRaw());
|
||||
const auto* ptr = context->Input<Tensor>(i);
|
||||
// Skip if the input is only a placeholder.
|
||||
if (ptr == nullptr) {
|
||||
p_serialized_kernel_context[index++] = 0;
|
||||
p_serialized_kernel_context[index++] = 0;
|
||||
continue;
|
||||
}
|
||||
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(ptr->DataRaw());
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape().NumDimensions());
|
||||
for (size_t d = 0; d < context->Input<Tensor>(i)->Shape().NumDimensions(); d++) {
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape()[d]);
|
||||
|
|
|
|||
74
onnxruntime/core/providers/js/operators/resize.cc
Normal file
74
onnxruntime/core/providers/js/operators/resize.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "resize.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
// Kernel definition shared by every Resize registration from opset 11 onwards.
// Inputs 1-3 are requested in CPU memory and both type constraints are limited to float tensors.
// Block comments are used inside the macro because a line comment would swallow the continuation.
#define JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS                       \
  (*KernelDefBuilder::Create())                                   \
      .InputMemoryType(OrtMemTypeCPUInput, 1) /* roi */           \
      .InputMemoryType(OrtMemTypeCPUInput, 2) /* scales */        \
      .InputMemoryType(OrtMemTypeCPUInput, 3) /* sizes */         \
      .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>()) \
      .TypeConstraint("T2", DataTypeImpl::GetTensorType<float>())

// Opset 10: only input 1 is requested in CPU memory and there is a single type constraint "T".
// NOTE(review): per the ONNX Resize-10 schema input 1 should be `scales` - confirm.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    10, 10,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 1)
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Resize);

// Opsets 11-12.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    11, 12,
    kJsExecutionProvider,
    JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS,
    Resize);

// Opsets 13-17.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    13, 17,
    kJsExecutionProvider,
    JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS,
    Resize);

// Opset 18.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Resize,
    kOnnxDomain,
    18, 18,
    kJsExecutionProvider,
    JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS,
    Resize);

// Opset 19 and later (open-ended registration).
ONNX_OPERATOR_KERNEL_EX(
    Resize,
    kOnnxDomain,
    19,
    kJsExecutionProvider,
    JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS,
    Resize);

#undef JSEP_RESIZE_KERNEL_DEF_OPSET11_PLUS
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
138
onnxruntime/core/providers/js/operators/resize.h
Normal file
138
onnxruntime/core/providers/js/operators/resize.h
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/tensor/upsamplebase.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Resize : public JsKernel, public UpsampleBase {
|
||||
public:
|
||||
Resize(const OpKernelInfo& info) : JsKernel(info), UpsampleBase(info) {
|
||||
const auto& node = info.node();
|
||||
opset_ = node.SinceVersion();
|
||||
|
||||
auto resize_coordinate_transformation_mode = ResizeCoordinateTransformationModeToString(coordinate_transform_mode_);
|
||||
auto keep_aspect_ratio_policy = KeepAspectRatioPolicyToString(keep_aspect_ratio_policy_);
|
||||
auto nearest_mode = NearestModeToString(nearest_mode_);
|
||||
auto mode = UpsampleModeToString(mode_);
|
||||
std::vector<int32_t> axes;
|
||||
std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast<int32_t>(axis); });
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({
|
||||
"antialias" : $1,
|
||||
"axes" : $2 ? Array.from(HEAP32.subarray($3, $3 + $2)) : [],
|
||||
"coordinateTransformMode" : UTF8ToString($4),
|
||||
"cubicCoeffA" : $5,
|
||||
"excludeOutside" : $6,
|
||||
"extrapolationValue" : $7,
|
||||
"keepAspectRatioPolicy" : UTF8ToString($8),
|
||||
"mode" : UTF8ToString($9),
|
||||
"nearestMode" : UTF8ToString($10),
|
||||
}),
|
||||
static_cast<int32_t>(antialias_),
|
||||
gsl::narrow_cast<int32_t>(axes.size()),
|
||||
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2,
|
||||
resize_coordinate_transformation_mode.c_str(),
|
||||
static_cast<double>(cubic_coeff_a_),
|
||||
static_cast<int32_t>(exclude_outside_),
|
||||
static_cast<double>(extrapolation_value_),
|
||||
keep_aspect_ratio_policy.c_str(),
|
||||
mode.c_str(),
|
||||
nearest_mode.c_str());
|
||||
}
|
||||
|
||||
std::string UpsampleModeToString(UpsampleMode mode) {
|
||||
switch (mode) {
|
||||
case UpsampleMode::NN:
|
||||
return UpsampleModeNN;
|
||||
case UpsampleMode::LINEAR:
|
||||
return UpsampleModeLinear;
|
||||
case UpsampleMode::CUBIC:
|
||||
return UpsampleModeCubic;
|
||||
default:
|
||||
ORT_THROW("UpsampleMode is not supported!");
|
||||
}
|
||||
}
|
||||
|
||||
std::string KeepAspectRatioPolicyToString(AspectRatioPolicy policy) {
|
||||
switch (policy) {
|
||||
case AspectRatioPolicy::STRETCH:
|
||||
return "stretch";
|
||||
case AspectRatioPolicy::NOT_LARGER:
|
||||
return "not_larger";
|
||||
case AspectRatioPolicy::NOT_SMALLER:
|
||||
return "not_smaller";
|
||||
default:
|
||||
ORT_THROW("AspectRatioPolicy is not supported!");
|
||||
}
|
||||
}
|
||||
|
||||
std::string ResizeCoordinateTransformationModeToString(const ResizeCoordinateTransformationMode mode) {
|
||||
switch (mode) {
|
||||
case ASYMMETRIC:
|
||||
return "asymmetric";
|
||||
case PYTORCH_HALF_PIXEL:
|
||||
return "pytorch_half_pixel";
|
||||
case TF_HALF_PIXEL_FOR_NN:
|
||||
return "tf_half_pixel_for_nn";
|
||||
case ALIGN_CORNERS:
|
||||
return "align_corners";
|
||||
case TF_CROP_AND_RESIZE:
|
||||
return "tf_crop_and_resize";
|
||||
case HALF_PIXEL:
|
||||
return "half_pixel";
|
||||
case HALF_PIXEL_SYMMETRIC:
|
||||
return "half_pixel_symmetric";
|
||||
default:
|
||||
ORT_THROW("ResizeCoordinateTransformationMode is not supported!");
|
||||
}
|
||||
}
|
||||
|
||||
std::string NearestModeToString(const ResizeNearestMode mode) {
|
||||
switch (mode) {
|
||||
case ROUND_PREFER_FLOOR:
|
||||
return "round_prefer_floor";
|
||||
case ROUND_PREFER_CEIL:
|
||||
return "round_prefer_ceil";
|
||||
case FLOOR:
|
||||
return "floor";
|
||||
case CEIL:
|
||||
return "ceil";
|
||||
default:
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
virtual Status SerializeCustomData(OpKernelContext* context, AllocatorPtr alloc, void** ptr, size_t* size) const {
|
||||
TensorShapeVector output_dims;
|
||||
std::vector<float> roi_array;
|
||||
std::vector<float> scales_array;
|
||||
|
||||
// Compute the size of the custom data
|
||||
size_t customDataSize = sizeof(int32_t);
|
||||
|
||||
// Allocate memory for custom data
|
||||
void* p_custom_data = alloc->Alloc(customDataSize);
|
||||
|
||||
// Validate the memory allocation
|
||||
if (p_custom_data == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "failed to allocate memory for the custom data");
|
||||
}
|
||||
|
||||
// Serialize the custom data
|
||||
int32_t* p_int32 = reinterpret_cast<int32_t*>(p_custom_data);
|
||||
*p_int32 = static_cast<int32_t>(opset_);
|
||||
|
||||
*ptr = p_custom_data;
|
||||
*size = customDataSize;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int opset_;
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue