mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] support GridSample operator (#22652)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
d9b91682f1
commit
b5ee4ac760
7 changed files with 358 additions and 8 deletions
|
|
@ -56,6 +56,7 @@ Do not modify directly.*
|
|||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| GreaterOrEqual | ai.onnx(12-15,16+) | |
|
||||
| GridSample | ai.onnx(16-19); com.ms.internal.nhwc(16-19) | |
|
||||
| GroupQueryAttention | com.microsoft(1+) | |
|
||||
| HardSigmoid | ai.onnx(6+) | |
|
||||
| If | ai.onnx(1-10,11-12,13-18,19-20,21+) | |
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
|
|||
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
|
||||
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
|
||||
import { gemm, parseGemmAttributes } from './ops/gemm';
|
||||
import { gridSample, parseGridSampleAttributes } from './ops/grid-sample';
|
||||
import { groupQueryAttention } from './ops/group-query-attention';
|
||||
import { instanceNorm } from './ops/instance-norm';
|
||||
import { layerNorm } from './ops/layer-norm';
|
||||
|
|
@ -104,6 +105,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
|
||||
['GridSample', [gridSample, parseGridSampleAttributes]],
|
||||
['GroupQueryAttention', [groupQueryAttention]],
|
||||
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
|
||||
['InstanceNormalization', [instanceNorm]],
|
||||
|
|
|
|||
279
js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts
Normal file
279
js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
|
||||
|
||||
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
|
||||
|
||||
let [idxN, idxC, idxH, idxW] = [0, 1, 2, 3]; // NCHW
|
||||
type Mode = 'bilinear' | 'nearest' | 'bicubic';
|
||||
type PaddingMode = 'zeros' | 'border' | 'reflection';
|
||||
type Format = 'NHWC' | 'NCHW';
|
||||
export interface GridSampeAttributes extends AttributeWithCacheKey {
|
||||
alignCorners: number;
|
||||
mode: Mode;
|
||||
paddingMode: PaddingMode;
|
||||
format: Format;
|
||||
}
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (inputs[0].dims.length !== 4) {
|
||||
throw new Error('only 4-D tensor is supported.');
|
||||
}
|
||||
if (inputs[0].dims.length !== inputs[1].dims.length) {
|
||||
throw new Error('input dimensions must be equal to grid dimensions');
|
||||
}
|
||||
|
||||
if (inputs[0].dims.length - 2 !== inputs[1].dims[inputs[1].dims.length - 1]) {
|
||||
throw new Error(`last dimension of grid must be equal to ${inputs[0].dims.length - 2}`);
|
||||
}
|
||||
|
||||
if (inputs[0].dims[0] !== inputs[1].dims[0]) {
|
||||
throw new Error('grid batch size must match input batch size');
|
||||
}
|
||||
};
|
||||
|
||||
const gsGetCubicCoeffs = `
|
||||
fn gs_get_cubic_coeffs(x: f32) -> vec4<f32> {
|
||||
let cubic_alpha = -0.75f;
|
||||
let x_abs = abs(x);
|
||||
var coeffs: vec4<f32>;
|
||||
coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha);
|
||||
coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1);
|
||||
coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1);
|
||||
coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha);
|
||||
return coeffs;
|
||||
}
|
||||
`;
|
||||
|
||||
const gsBicubicInterpolate = (dataType: string): string => `
|
||||
fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} {
|
||||
var v: vec4<f32>;
|
||||
var coeffs = gs_get_cubic_coeffs(x);
|
||||
for (var i = 0; i < 4; i++) {
|
||||
v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
|
||||
}
|
||||
coeffs = gs_get_cubic_coeffs(y);
|
||||
let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
|
||||
return pixel;
|
||||
}
|
||||
`;
|
||||
|
||||
const gsDenormalize = (attributes: GridSampeAttributes): string => `
|
||||
fn gs_denormalize(n: f32, length: i32) -> f32 {
|
||||
${
|
||||
attributes.alignCorners === 0
|
||||
? `
|
||||
// alignCorners: false => [-1, 1] to [-0.5, length - 0.5]
|
||||
return ((n + 1.0) * f32(length) - 1.0) / 2.0;
|
||||
`
|
||||
: `
|
||||
// alignCorners: true => [-1, 1] to [0, length - 1]
|
||||
return (n + 1.0) / 2.0 * (f32(length - 1));
|
||||
`
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
const gsReflect = (attributes: GridSampeAttributes): string => `
|
||||
${
|
||||
attributes.paddingMode === 'reflection'
|
||||
? `
|
||||
fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 {
|
||||
var dx = 0.0;
|
||||
var fx = f32(x);
|
||||
let range = x_max - x_min;
|
||||
if (fx < x_min) {
|
||||
dx = x_min - fx;
|
||||
let n = u32(dx / range);
|
||||
let r = dx - f32(n) * range;
|
||||
if (n % 2 == 0) {
|
||||
fx = x_min + r;
|
||||
} else {
|
||||
fx = x_max - r;
|
||||
}
|
||||
} else if (fx > x_max) {
|
||||
dx = fx - x_max;
|
||||
let n = u32(dx / range);
|
||||
let r = dx - f32(n) * range;
|
||||
if (n % 2 == 0) {
|
||||
fx = x_max - r;
|
||||
} else {
|
||||
fx = x_min + r;
|
||||
}
|
||||
}
|
||||
return u32(fx);
|
||||
}`
|
||||
: ''
|
||||
}
|
||||
`;
|
||||
|
||||
const pixelAtGrid = (input: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
|
||||
`
|
||||
fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4<f32>) -> ${dataType} {
|
||||
var pixel = ${dataType}(0);
|
||||
var indices = vec4<u32>(0);
|
||||
indices[${idxN}] = batch;
|
||||
indices[${idxC}] = channel;` +
|
||||
(() => {
|
||||
switch (attributes.paddingMode) {
|
||||
case 'zeros':
|
||||
return `
|
||||
if (r >= 0 && r < H && c >=0 && c < W) {
|
||||
indices[${idxH}] = u32(r);
|
||||
indices[${idxW}] = u32(c);
|
||||
}
|
||||
`;
|
||||
case 'border':
|
||||
return `
|
||||
indices[${idxH}] = u32(clamp(r, 0, H - 1));
|
||||
indices[${idxW}] = u32(clamp(c, 0, W - 1));
|
||||
`;
|
||||
case 'reflection':
|
||||
return `
|
||||
indices[${idxH}] = gs_reflect(r, border[1], border[3]);
|
||||
indices[${idxW}] = gs_reflect(c, border[0], border[2]);
|
||||
`;
|
||||
default:
|
||||
throw new Error(`padding mode ${attributes.paddingMode} is not supported`);
|
||||
}
|
||||
})() +
|
||||
`
|
||||
return ${input.getByIndices('indices')};
|
||||
}
|
||||
`;
|
||||
|
||||
const computePixel = (output: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
|
||||
(() => {
|
||||
switch (attributes.mode) {
|
||||
case 'nearest':
|
||||
return `
|
||||
let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
`;
|
||||
case 'bilinear':
|
||||
return `
|
||||
let x1 = i32(floor(x));
|
||||
let y1 = i32(floor(y));
|
||||
let x2 = x1 + 1;
|
||||
let y2 = y1 + 1;
|
||||
|
||||
let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
|
||||
let dx2 = ${dataType}(f32(x2) - x);
|
||||
let dx1 = ${dataType}(x - f32(x1));
|
||||
let dy2 = ${dataType}(f32(y2) - y);
|
||||
let dy1 = ${dataType}(y - f32(y1));
|
||||
let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
|
||||
`;
|
||||
case 'bicubic':
|
||||
return `
|
||||
let x0 = i32(floor(x)) - 1;
|
||||
let y0 = i32(floor(y)) - 1;
|
||||
var p: mat4x4<${dataType}>;
|
||||
for (var h = 0; h < 4; h++) {
|
||||
for (var w = 0; w < 4; w++) {
|
||||
p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
|
||||
}
|
||||
}
|
||||
|
||||
let dx = x - f32(x0 + 1);
|
||||
let dy = y - f32(y0 + 1);
|
||||
let result = gs_bicubic_interpolate(p, dx, dy);
|
||||
`;
|
||||
default:
|
||||
throw new Error(`mode ${attributes.mode} is not supported`);
|
||||
}
|
||||
})() + `${output.setByOffset('global_idx', 'result')}`;
|
||||
|
||||
const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => {
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length);
|
||||
// discard last dimension for using vec2 to access grid data
|
||||
const gridShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2]];
|
||||
const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2);
|
||||
let outputShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[1].dims[1], inputs[1].dims[2]];
|
||||
if (attributes.format === 'NHWC') {
|
||||
outputShape = [inputs[0].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[0].dims[3]];
|
||||
[idxN, idxC, idxH, idxW] = [0, 3, 1, 2];
|
||||
}
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
|
||||
const dataType = x.type.value;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
...createTensorShapeVariables(inputs[0].dims, gridShape, outputShape),
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)}
|
||||
${gsGetCubicCoeffs}
|
||||
${gsBicubicInterpolate(dataType)}
|
||||
${gsDenormalize(attributes)}
|
||||
${gsReflect(attributes)}
|
||||
${pixelAtGrid(x, dataType, attributes)}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let H_in = i32(uniforms.x_shape[${idxH}]);
|
||||
let W_in = i32(uniforms.x_shape[${idxW}]);
|
||||
|
||||
${
|
||||
attributes.alignCorners === 0
|
||||
? `
|
||||
let x_min = -0.5;
|
||||
let x_max = f32(W_in) - 0.5;
|
||||
let y_min = -0.5;
|
||||
let y_max = f32(H_in) - 0.5;
|
||||
`
|
||||
: `
|
||||
let x_min = 0.0;
|
||||
let x_max = f32(W_in) - 1.0;
|
||||
let y_min = 0.0;
|
||||
let y_max = f32(H_in) - 1.0;
|
||||
`
|
||||
};
|
||||
let border = vec4<f32>(x_min, y_min, x_max, y_max);
|
||||
|
||||
let indices = ${output.offsetToIndices('global_idx')};
|
||||
var grid_indices = vec3<u32>(indices[${idxN}], indices[${idxH}], indices[${idxW}]);
|
||||
let nxy = ${grid.getByIndices('grid_indices')};
|
||||
var x = gs_denormalize(f32(nxy[0]), W_in);
|
||||
var y = gs_denormalize(f32(nxy[1]), H_in);
|
||||
|
||||
${computePixel(output, dataType, attributes)}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'GridSample',
|
||||
shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] },
|
||||
getRunData: (inputs) => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
|
||||
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
programUniforms,
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
context.compute(createGridSampleProgramInfo(context.inputs, attributes));
|
||||
};
|
||||
|
||||
export const parseGridSampleAttributes = (attributes: Record<string, unknown>): GridSampeAttributes =>
|
||||
createAttributeWithCacheKey({
|
||||
alignCorners: attributes.align_corners as number,
|
||||
mode: attributes.mode as Mode,
|
||||
paddingMode: attributes.padding_mode as PaddingMode,
|
||||
format: attributes.format as Format,
|
||||
});
|
||||
|
|
@ -570,14 +570,14 @@
|
|||
"test_greater_equal_expanded",
|
||||
"test_greater_equal",
|
||||
"test_greater",
|
||||
// // "test_gridsample_aligncorners_true",
|
||||
// // "test_gridsample_bicubic",
|
||||
// // "test_gridsample_bilinear",
|
||||
// // "test_gridsample_border_padding",
|
||||
// // "test_gridsample_nearest",
|
||||
// // "test_gridsample_reflection_padding",
|
||||
// // "test_gridsample_zeros_padding",
|
||||
// // "test_gridsample",
|
||||
"test_gridsample_aligncorners_true",
|
||||
"test_gridsample_bicubic",
|
||||
"test_gridsample_bilinear",
|
||||
"test_gridsample_border_padding",
|
||||
"test_gridsample_nearest",
|
||||
"test_gridsample_reflection_padding",
|
||||
"test_gridsample_zeros_padding",
|
||||
"test_gridsample",
|
||||
// // "test_gru_batchwise",
|
||||
// // "test_gru_defaults",
|
||||
// // "test_gru_seq_length",
|
||||
|
|
|
|||
|
|
@ -400,6 +400,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -728,6 +731,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
30
onnxruntime/core/providers/js/operators/grid_sample.cc
Normal file
30
onnxruntime/core/providers/js/operators/grid_sample.cc
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "grid_sample.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
GridSample,
|
||||
kMSInternalNHWCDomain,
|
||||
16, 19,
|
||||
kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", JsepSupportedDataTypes())
|
||||
.TypeConstraint("T2", JsepSupportedFloatTypes()),
|
||||
GridSample<true>);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
GridSample,
|
||||
kOnnxDomain,
|
||||
16, 19,
|
||||
kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T1", JsepSupportedDataTypes())
|
||||
.TypeConstraint("T2", JsepSupportedFloatTypes()),
|
||||
GridSample<false>);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
32
onnxruntime/core/providers/js/operators/grid_sample.h
Normal file
32
onnxruntime/core/providers/js/operators/grid_sample.h
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
template <bool is_channels_last>
|
||||
class GridSample : public JsKernel {
|
||||
public:
|
||||
GridSample(const OpKernelInfo& info) : JsKernel(info) {
|
||||
int64_t align_corners = info.GetAttrOrDefault<int64_t>("align_corners", 0);
|
||||
std::string mode = info.GetAttrOrDefault<std::string>("mode", "linear");
|
||||
std::string padding_mode = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
|
||||
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(GridSample, ({
|
||||
"align_corners" : $1,
|
||||
"mode" : UTF8ToString($2),
|
||||
"padding_mode" : UTF8ToString($3),
|
||||
"format" : $4 ? "NHWC" : "NCHW"
|
||||
}),
|
||||
static_cast<int32_t>(align_corners), mode.c_str(),
|
||||
padding_mode.c_str(), static_cast<int32_t>(channels_last));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue