mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[JS/WebGPU] Add trilinear interpolation to Resize; activation_params attribute is optional for FusedConv also. (#18842)
### Description Add trilinear interpolation to Resize and changed activation_params attribute as optional for FuseConv. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
31d4a21c4b
commit
3bbe4fe2ff
4 changed files with 254 additions and 51 deletions
|
|
@ -219,7 +219,7 @@ const initOutputShape =
|
|||
return outputShape;
|
||||
};
|
||||
|
||||
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes): number[] => {
|
||||
const adjustOutputShape = (inputShape: readonly number[], scales: number[], attributes: ResizeAttributes) => {
|
||||
const scaleInPolicy = (() => {
|
||||
switch (attributes.keepAspectRatioPolicy) {
|
||||
case 'not_larger':
|
||||
|
|
@ -312,21 +312,27 @@ const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]):
|
|||
return true;
|
||||
}`;
|
||||
|
||||
const setChannelAndBatchIndices =
|
||||
(input: IndicesHelper, channelIdx: number, batchIdx: number, spacialDims: number): string =>
|
||||
input.rank > spacialDims ? `
|
||||
${input.indicesSet('input_indices', channelIdx, 'channel')};
|
||||
${input.indicesSet('input_indices', batchIdx, 'batch')};
|
||||
` :
|
||||
'';
|
||||
|
||||
const bilinearInterpolation =
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], scales: readonly number[],
|
||||
useExtrapolation: boolean, extrapolationValue: number): string => {
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
|
||||
extrapolationValue: number): string => {
|
||||
const isNchw = true;
|
||||
const [batchIdx, heightIdx, widthIdx, channelIdx] =
|
||||
inputShape.length === 2 ? [-1, 0, 1, -1] : (scales[1] === 1.0 ? [0, 2, 3, 1] : [0, 1, 2, 3]);
|
||||
inputShape.length === 2 ? [-1, 0, 1, -1] : (isNchw ? [0, 2, 3, 1] : [0, 1, 2, 3]);
|
||||
const dType = input.type.value;
|
||||
return `
|
||||
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
|
||||
var input_indices: ${input.type.indices};
|
||||
${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
|
||||
${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
|
||||
if (${inputShape.length} > 2) {
|
||||
${input.indicesSet('input_indices', channelIdx, 'channel')};
|
||||
${input.indicesSet('input_indices', batchIdx, 'batch')};
|
||||
};
|
||||
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 2)}
|
||||
return ${input.getByIndices('input_indices')};
|
||||
}
|
||||
|
||||
|
|
@ -334,30 +340,36 @@ const bilinearInterpolation =
|
|||
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
|
||||
var row:${dType} = originalIndices[${heightIdx}];
|
||||
var col:${dType} = originalIndices[${widthIdx}];
|
||||
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
|
||||
inputShape[widthIdx]} - 1)) {
|
||||
${
|
||||
useExtrapolation ?
|
||||
`if (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > (${inputShape[widthIdx]} - 1))) {
|
||||
return ${extrapolationValue};
|
||||
}
|
||||
}` :
|
||||
''};
|
||||
row = max(0, min(row, ${inputShape[heightIdx]} - 1));
|
||||
col = max(0, min(col, ${inputShape[widthIdx]} - 1));
|
||||
var row1: u32 = u32(row);
|
||||
var col1: u32 = u32(col);
|
||||
var row2: u32 = u32(row + 1);
|
||||
var col2: u32 = u32(col + 1);
|
||||
var channel: u32 = 0;
|
||||
var batch: u32 = 0;
|
||||
if (${inputShape.length > 2}) {
|
||||
channel = u32(originalIndices[${channelIdx}]);
|
||||
batch = u32(originalIndices[${batchIdx}]);
|
||||
}
|
||||
var channel: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${channelIdx}])` : '0'};
|
||||
var batch: u32 = ${inputShape.length > 2 ? `u32(originalIndices[${batchIdx}])` : '0'};
|
||||
var x11: ${dType} = getInputValue(batch, channel, row1, col1);
|
||||
var x12: ${dType} = getInputValue(batch, channel, row1, col2);
|
||||
var x21: ${dType} = getInputValue(batch, channel, row2, col1);
|
||||
var x22: ${dType} = getInputValue(batch, channel, row2, col2);
|
||||
var dx1: ${dType} = row - ${dType}(row1);
|
||||
var dx2: ${dType} = ${dType}(row2) - row;
|
||||
var dy1 = col - ${dType}(col1);
|
||||
var dy2 = ${dType}(col2) - col;
|
||||
var dx1: ${dType} = abs(row - ${dType}(row1));
|
||||
var dx2: ${dType} = abs(${dType}(row2) - row);
|
||||
var dy1: ${dType} = abs(col - ${dType}(col1));
|
||||
var dy2: ${dType} = abs(${dType}(col2) - col);
|
||||
if (row1 == row2) {
|
||||
dx1 = 0.5;
|
||||
dx2 = 0.5;
|
||||
}
|
||||
if (col1 == col2) {
|
||||
dy1 = 0.5;
|
||||
dy2 = 0.5;
|
||||
}
|
||||
return (x11 * dx2 * dy2 + x12 * dx2 * dy1 + x21 * dx1 * dy2 + x22 * dx1 * dy1);
|
||||
}`;
|
||||
};
|
||||
|
|
@ -366,7 +378,9 @@ const bicubicInterpolation =
|
|||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
|
||||
scales: readonly number[], roi: readonly number[], cubicCoeffA: number, useExtrapolation: boolean,
|
||||
extrapolationValue: number, excludeOutside: boolean): string => {
|
||||
const [heightIdx, widthIdx] = inputShape.length === 2 ? [0, 1] : (scales[1] === 1.0) ? [2, 3] : [1, 2];
|
||||
const is2D = inputShape.length === 2;
|
||||
const isNchw = true;
|
||||
const [heightIdx, widthIdx] = is2D ? [0, 1] : isNchw ? [2, 3] : [1, 2];
|
||||
const dType = input.type.value;
|
||||
const createCubicInterpolationFunction = (idx: number): string => {
|
||||
const direction = idx === heightIdx ? 'row' : 'col';
|
||||
|
|
@ -386,16 +400,18 @@ const bicubicInterpolation =
|
|||
for (var i: i32 = -1; i < 3; i++) {
|
||||
var ${direction}: ${dType} = originalIdx + ${dType}(i);
|
||||
if (${direction} < 0 || ${direction} >= ${inputShape[idx]}) {
|
||||
if (${excludeOutside}) {
|
||||
coefs[i + 1] = 0.0;
|
||||
continue;
|
||||
} else if (${useExtrapolation}) {
|
||||
return ${extrapolationValue};
|
||||
} else {
|
||||
${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
|
||||
}
|
||||
${(() => {
|
||||
if (excludeOutside) {
|
||||
return `coefs[i + 1] = 0.0;
|
||||
continue;`;
|
||||
} else if (useExtrapolation) {
|
||||
return `return ${extrapolationValue};`;
|
||||
} else {
|
||||
return `${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));`;
|
||||
}
|
||||
var input_indices_copy: ${input.type.indices} = input_indices;
|
||||
})()};
|
||||
}
|
||||
var input_indices_copy: ${input.type.indices} = input_indices;
|
||||
${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
|
||||
data[i + 1] = ${
|
||||
idx === heightIdx ? input.getByIndices('input_indices_copy') :
|
||||
|
|
@ -435,6 +451,78 @@ const bicubicInterpolation =
|
|||
`;
|
||||
};
|
||||
|
||||
const trilinearInterpolation =
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], useExtrapolation: boolean,
|
||||
extrapolationValue: number): string => {
|
||||
const isNchw = true;
|
||||
const [batchIdx, depthIdx, heightIdx, widthIdx, channelIdx] =
|
||||
inputShape.length === 3 ? [-1, 0, 1, 2, -1] : (isNchw ? [0, 2, 3, 4, 1] : [0, 1, 2, 3, 4]);
|
||||
const dType = input.type.value;
|
||||
return `
|
||||
fn getInputValue(batch: u32, channel: u32, depth:u32, height: u32, width: u32) -> ${dType} {
|
||||
var input_indices: ${input.type.indices};
|
||||
${input.indicesSet('input_indices', depthIdx, `max(0, min(depth, ${inputShape[depthIdx]} - 1))`)};
|
||||
${input.indicesSet('input_indices', heightIdx, `max(0, min(height, ${inputShape[heightIdx]} - 1))`)};
|
||||
${input.indicesSet('input_indices', widthIdx, `max(0, min(width, ${inputShape[widthIdx]} - 1))`)};
|
||||
${setChannelAndBatchIndices(input, channelIdx, batchIdx, 3)}
|
||||
return ${input.getByIndices('input_indices')};
|
||||
}
|
||||
|
||||
fn trilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
|
||||
var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
|
||||
var depth:${dType} = originalIndices[${depthIdx}];
|
||||
var height:${dType} = originalIndices[${heightIdx}];
|
||||
var width:${dType} = originalIndices[${widthIdx}];
|
||||
${
|
||||
useExtrapolation ? `if (depth < 0 || depth > (${inputShape[depthIdx]} - 1) || height < 0 || height > (${
|
||||
inputShape[heightIdx]} - 1) || width < 0 || (width > ${inputShape[widthIdx]} - 1))) {
|
||||
return ${extrapolationValue};
|
||||
}` :
|
||||
''};
|
||||
|
||||
depth = max(0, min(depth, ${inputShape[depthIdx]} - 1));
|
||||
height = max(0, min(height, ${inputShape[heightIdx]} - 1));
|
||||
width = max(0, min(width, ${inputShape[widthIdx]} - 1));
|
||||
var depth1: u32 = u32(depth);
|
||||
var height1: u32 = u32(height);
|
||||
var width1: u32 = u32(width);
|
||||
var depth2: u32 = u32(depth + 1);
|
||||
var height2: u32 = u32(height + 1);
|
||||
var width2: u32 = u32(width + 1);
|
||||
var channel: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${channelIdx}])` : '0'};
|
||||
var batch: u32 = ${inputShape.length > 3 ? `u32(originalIndices[${batchIdx}])` : '0'};
|
||||
|
||||
var x111: ${dType} = getInputValue(batch, channel, depth1, height1, width1);
|
||||
var x112: ${dType} = getInputValue(batch, channel, depth1, height1, width2);
|
||||
var x121: ${dType} = getInputValue(batch, channel, depth1, height2, width1);
|
||||
var x122: ${dType} = getInputValue(batch, channel, depth1, height2, width2);
|
||||
var x211: ${dType} = getInputValue(batch, channel, depth2, height1, width1);
|
||||
var x212: ${dType} = getInputValue(batch, channel, depth2, height1, width2);
|
||||
var x221: ${dType} = getInputValue(batch, channel, depth2, height2, width1);
|
||||
var x222: ${dType} = getInputValue(batch, channel, depth2, height2, width2);
|
||||
var dx1: ${dType} = abs(depth - ${dType}(depth1));
|
||||
var dx2: ${dType} = abs(${dType}(depth2) - depth);
|
||||
var dy1: ${dType} = abs(height - ${dType}(height1));
|
||||
var dy2: ${dType} = abs(${dType}(height2) - height);
|
||||
var dz1: ${dType} = abs(width - ${dType}(width1));
|
||||
var dz2: ${dType} = abs(${dType}(width2) - width);
|
||||
if (depth1 == depth2) {
|
||||
dx1 = 0.5;
|
||||
dx2 = 0.5;
|
||||
}
|
||||
if (height1 == height2) {
|
||||
dy1 = 0.5;
|
||||
dy2 = 0.5;
|
||||
}
|
||||
if (width1 == width2) {
|
||||
dz1 = 0.5;
|
||||
dz2 = 0.5;
|
||||
}
|
||||
return (x111 * dx2 * dy2 * dz2 + x112 * dx2 * dy2 * dz1 + x121 * dx2 * dy1 *dz2 + x122 * dx2 * dy1 * dz1 +
|
||||
x211 * dx1 * dy2 * dz2 + x212 * dx1 * dy2 * dz1 + x221 * dx1 * dy1 *dz2 + x222 * dx1 * dy1 * dz1);
|
||||
}`;
|
||||
};
|
||||
|
||||
const createResizeProgramInfo =
|
||||
(inputTensor: TensorView, attributes: ResizeAttributes, opsetVersion: number, scalesInput: readonly number[],
|
||||
sizes: readonly number[], roiInput: readonly number[]): ProgramInfo => {
|
||||
|
|
@ -454,6 +542,7 @@ const createResizeProgramInfo =
|
|||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
|
||||
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
|
||||
const extrapolationValue = attributes.extrapolationValue;
|
||||
const dataType = input.type.value;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${noScale ? '' : `
|
||||
|
|
@ -471,16 +560,28 @@ const createResizeProgramInfo =
|
|||
case 'linear':
|
||||
return `
|
||||
${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
|
||||
${
|
||||
bilinearInterpolation(
|
||||
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)};
|
||||
`;
|
||||
${(() => {
|
||||
if (inputShape.length === 2 || inputShape.length === 4) {
|
||||
return `${bilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
|
||||
} else if (inputShape.length === 3 || inputShape.length === 5) {
|
||||
return `${trilinearInterpolation(input, output, inputShape, useExtrapolation, extrapolationValue)}`;
|
||||
} else {
|
||||
throw Error('Linear mode only supports input dims 2, 3, 4 and 5 are supported in linear mode.');
|
||||
}
|
||||
})()};
|
||||
`;
|
||||
case 'cubic':
|
||||
return `
|
||||
${
|
||||
bicubicInterpolation(
|
||||
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
|
||||
attributes.extrapolationValue, attributes.excludeOutside)};
|
||||
${(() => {
|
||||
if (inputShape.length === 2 || inputShape.length === 4) {
|
||||
return `${
|
||||
bicubicInterpolation(
|
||||
input, output, inputShape, outputShape, scales, roi, attributes.cubicCoeffA, useExtrapolation,
|
||||
attributes.extrapolationValue, attributes.excludeOutside)}`;
|
||||
} else {
|
||||
throw Error('Cubic mode only supports input dims 2 and 4 are supported in linear mode.');
|
||||
}
|
||||
})()};
|
||||
`;
|
||||
default:
|
||||
throw Error('Invalid resize mode');
|
||||
|
|
@ -507,21 +608,23 @@ const createResizeProgramInfo =
|
|||
output[global_idx] = ${attributes.extrapolationValue};
|
||||
}`;
|
||||
case 'linear':
|
||||
return 'output[global_idx] = bilinearInterpolation(output_indices);';
|
||||
return `output[global_idx] = ${
|
||||
(inputShape.length === 2 || inputShape.length === 4) ? 'bilinearInterpolation' :
|
||||
'trilinearInterpolation'}(output_indices);`;
|
||||
case 'cubic':
|
||||
return 'output[global_idx] = bicubicInterpolation(output_indices);';
|
||||
default:
|
||||
throw Error(`Unsupported resize mode: ${attributes.mode}`);
|
||||
}
|
||||
})()};
|
||||
`}
|
||||
`}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'Resize',
|
||||
shaderCache: {
|
||||
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
|
||||
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
|
||||
sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}|${inputShape}`,
|
||||
inputDependencies: ['rank']
|
||||
},
|
||||
getShaderSource,
|
||||
|
|
@ -551,6 +654,9 @@ export const resize = (context: ComputeContext, attributes: ResizeAttributes): v
|
|||
const sizes: number[] = [];
|
||||
const roi: number[] = [];
|
||||
const opsetVersion = getOpsetVersionFromCustomDataBuffer(context);
|
||||
if (attributes.antialias !== 0) {
|
||||
throw Error('Only default value (0) for Antialias attribute is supported');
|
||||
}
|
||||
validateInputs(context.inputs, attributes, opsetVersion, scales, sizes, roi);
|
||||
context.compute(
|
||||
createResizeProgramInfo(context.inputs[0], attributes, opsetVersion, scales, sizes, roi), {inputs: [0]});
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
{
|
||||
"name": "Upsample - Nearest",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 3.0], "type": "floats" }],
|
||||
"cases": [
|
||||
{
|
||||
|
|
@ -32,6 +33,7 @@
|
|||
{
|
||||
"name": "Upsample - Nearest2X",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 2.0], "type": "floats" }],
|
||||
"cases": [
|
||||
{
|
||||
|
|
@ -60,6 +62,7 @@
|
|||
{
|
||||
"name": "Upsample - Nearest222X",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [{ "name": "scales", "data": [2.0, 1.0, 2.0, 2.0], "type": "floats" }],
|
||||
"cases": [
|
||||
{
|
||||
|
|
@ -92,6 +95,7 @@
|
|||
{
|
||||
"name": "Upsample - Nearest15X",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [{ "name": "scales", "data": [1.0, 1.0, 2.0, 1.5], "type": "floats" }],
|
||||
"cases": [
|
||||
{
|
||||
|
|
@ -120,6 +124,7 @@
|
|||
{
|
||||
"name": "Upsample - Nearest_NoScale",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "nearest", "type": "string" }
|
||||
|
|
@ -147,6 +152,7 @@
|
|||
{
|
||||
"name": "Upsample - 4D Bilinear",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 1.0, 2.0, 4.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
|
|
@ -180,6 +186,7 @@
|
|||
{
|
||||
"name": "Upsample - 2D Bilinear",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [2.0, 4.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
|
|
@ -210,6 +217,7 @@
|
|||
{
|
||||
"name": "Upsample - 4D Bilinear ScalesNoOp",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 1.0, 1.0, 1.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
|
|
@ -237,6 +245,7 @@
|
|||
{
|
||||
"name": "Upsample - 1D Nearest",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [2.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "nearest", "type": "string" }
|
||||
|
|
@ -260,5 +269,98 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Upsample - 5D Trilinear",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 1.0, 1.0, 2.0, 4.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "X",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
|
||||
"dims": [1, 2, 1, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0,
|
||||
5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0,
|
||||
|
||||
3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0, 5.0, 5.5, 6.0, 6.5, 7.0, 7.0, 7.0, 7.0, 7.0, 7.5, 8.0, 8.5, 9.0,
|
||||
9.0, 9.0, 9.0, 7.0, 7.5, 8.0, 8.5, 9.0, 9.0, 9.0, 9.0
|
||||
],
|
||||
"dims": [1, 2, 1, 4, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Upsample - 3D Trilinear",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 2.0, 4.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "X",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0, 3.0, 3.0, 5.0],
|
||||
"dims": [1, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1.0, 1.5, 2.0, 2.5, 3.0, 3.0, 3.0, 3.0, 2.0, 2.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 3.0, 3.5, 4.0, 4.5, 5.0,
|
||||
5.0, 5.0, 5.0, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0
|
||||
],
|
||||
"dims": [1, 4, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Upsample - 3D Trilinear ScalesNoOp",
|
||||
"operator": "Upsample",
|
||||
"opset": { "domain": "", "version": 7 },
|
||||
"attributes": [
|
||||
{ "name": "scales", "data": [1.0, 1.0, 1.0], "type": "floats" },
|
||||
{ "name": "mode", "data": "linear", "type": "string" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "X",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
|
||||
"dims": [2, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 3.0, 3.0, 5.0, 3.0, 5.0, 7.0, 9.0],
|
||||
"dims": [2, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1392,6 +1392,7 @@
|
|||
"tile.jsonc",
|
||||
"transpose.jsonc",
|
||||
"transpose_int32_uint32.jsonc",
|
||||
"upsample.jsonc",
|
||||
"where.jsonc"
|
||||
// Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed.
|
||||
//"where_broadcast.jsonc",
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/cpu/nn/conv_attributes.h"
|
||||
|
|
@ -17,7 +17,6 @@ class ConvBase : public JsKernel {
|
|||
ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info),
|
||||
conv_attrs_(info),
|
||||
w_is_const_(false) {
|
||||
std::vector<float> activation_params;
|
||||
TensorShapeVector kernel_shape;
|
||||
const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size();
|
||||
std::vector<int32_t> local_pads(pads_vec_size, 0);
|
||||
|
|
@ -28,13 +27,8 @@ class ConvBase : public JsKernel {
|
|||
if (conv_attrs_.kernel_shape_specified) {
|
||||
ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
|
||||
}
|
||||
if (is_fused_conv) {
|
||||
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &conv_attrs_.activation));
|
||||
ORT_THROW_IF_ERROR(info.GetAttrs<float>("activation_params", activation_params));
|
||||
} else {
|
||||
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
activation_params = info.GetAttrsOrDefault<float>("activation_params", activation_params);
|
||||
}
|
||||
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
|
||||
const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr;
|
||||
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
|
||||
auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;
|
||||
|
|
|
|||
Loading…
Reference in a new issue