mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
[WebGPU/JS] Added Pad operator support (#16928)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e11849e716
commit
198d468849
7 changed files with 379 additions and 5 deletions
|
|
@ -59,6 +59,7 @@ Do not modify directly.*
|
|||
| Mul | ai.onnx(7-12,13,14+) | |
|
||||
| Neg | ai.onnx(6-12,13+) | |
|
||||
| Not | ai.onnx(1+) | |
|
||||
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
|
||||
| Pow | ai.onnx(7-11,12,13-14,15+) | |
|
||||
| Reciprocal | ai.onnx(6-12,13+) | |
|
||||
| ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | |
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
|
|||
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
|
||||
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
|
||||
import {matMul} from './ops/matmul';
|
||||
import {pad, parsePadAttributes} from './ops/pad';
|
||||
import * as pool from './ops/pool';
|
||||
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
|
||||
import {parseResizeAttributes, resize} from './ops/resize';
|
||||
|
|
@ -80,6 +81,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Mul', [binaryOps.mul]],
|
||||
['Neg', [unaryOps.neg]],
|
||||
['Not', [unaryOps.not]],
|
||||
['Pad', [pad, parsePadAttributes]],
|
||||
['Pow', [binaryOps.pow]],
|
||||
['Reciprocal', [unaryOps.reciprocal]],
|
||||
['ReduceMin', [reduceMin, parseReduceAttributes]],
|
||||
|
|
|
|||
252
js/web/lib/wasm/jsep/webgpu/ops/pad.ts
Normal file
252
js/web/lib/wasm/jsep/webgpu/ops/pad.ts
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface PadAttributes extends AttributeWithCacheKey {
|
||||
// 0-constant, 1-reflect, 2-edge, 3-wrap
|
||||
readonly mode: number;
|
||||
readonly value: number;
|
||||
readonly pads: number[];
|
||||
}
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length < 1) {
|
||||
throw new Error('Too few inputs');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Input type must be float.');
|
||||
}
|
||||
|
||||
if (inputs.length >= 2) {
|
||||
let validPads = inputs[0].dims.length * 2 === inputs[1].dims[0];
|
||||
if (inputs.length === 4) {
|
||||
validPads = inputs[3].dims[0] * 2 === inputs[1].dims[0];
|
||||
}
|
||||
if (!validPads) {
|
||||
throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const getPadConstant =
|
||||
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
|
||||
inputStrides: readonly number[], pads: number[], dataType: string, constantValue: number): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
if (k < 0) {
|
||||
break;
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
break;
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
`;
|
||||
}
|
||||
|
||||
return `
|
||||
value = ${dataType}(${constantValue});
|
||||
for (var i = 0; i < 1; i++) {
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
}
|
||||
`;
|
||||
};
|
||||
|
||||
const getPadReflect =
|
||||
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
|
||||
inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
if (k < 0) {
|
||||
k = -k;
|
||||
}
|
||||
{
|
||||
let _2n_1 = ${2 * (inputDims[i] - 1)};
|
||||
k = k % _2n_1;
|
||||
if(k >= ${inputDims[i]}) {
|
||||
k = _2n_1 - k;
|
||||
}
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
`;
|
||||
}
|
||||
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
|
||||
const getPadEdge =
|
||||
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
|
||||
inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
if (k < 0) {
|
||||
k = 0;
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
k = ${inputDims[i] - 1};
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
`;
|
||||
}
|
||||
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
|
||||
const getPadWrap =
|
||||
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
|
||||
inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
if (k < 0) {
|
||||
k += ${inputDims[i]};
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
k -= ${inputDims[i]};
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
`;
|
||||
}
|
||||
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
|
||||
const getPadSnippet =
|
||||
(output: IndicesHelper, outputDims: readonly number[], inputDims: readonly number[],
|
||||
inputStrides: readonly number[], attributes: PadAttributes, dataType: string): string => {
|
||||
switch (attributes.mode) {
|
||||
case 0:
|
||||
return getPadConstant(
|
||||
output, outputDims, inputDims, inputStrides, attributes.pads, dataType, attributes.value);
|
||||
case 1:
|
||||
return getPadReflect(output, outputDims, inputDims, inputStrides, attributes.pads);
|
||||
case 2:
|
||||
return getPadEdge(output, outputDims, inputDims, inputStrides, attributes.pads);
|
||||
case 3:
|
||||
return getPadWrap(output, outputDims, inputDims, inputStrides, attributes.pads);
|
||||
default:
|
||||
throw new Error('Invalid mode');
|
||||
}
|
||||
};
|
||||
|
||||
const generatePadCode =
|
||||
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string):
|
||||
string => {
|
||||
const inputDims = inputs[0].dims;
|
||||
const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads);
|
||||
const outputSize = ShapeUtil.size(outputDims);
|
||||
const inputStrides = ShapeUtil.computeStrides(inputDims);
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputDims);
|
||||
const input = inputVariable('x', inputs[0].dataType, inputDims);
|
||||
|
||||
const padSnippet = getPadSnippet(output, outputDims, inputDims, inputStrides, attributes, dataType);
|
||||
const padCode = `
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
${output.impl()}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
||||
let indices = ${output.offsetToIndices('global_idx')};
|
||||
|
||||
var value = ${dataType}(0);
|
||||
${padSnippet}
|
||||
output[global_idx] = value;
|
||||
}`;
|
||||
return padCode;
|
||||
};
|
||||
|
||||
const createPadProgramInfo =
|
||||
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: PadAttributes): ProgramInfo => {
|
||||
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
|
||||
return {
|
||||
...metadata,
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
|
||||
getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'),
|
||||
dispatchGroup: () => ({x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)})
|
||||
};
|
||||
};
|
||||
|
||||
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
|
||||
if (inputs.length > 1) {
|
||||
const bigInt64Pads = inputs[1].getBigInt64Array();
|
||||
const value = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : 0.0;
|
||||
|
||||
const inputRank = inputs[0].dims.length;
|
||||
const updatePads = new Int32Array(2 * inputRank).fill(0);
|
||||
if (inputs.length >= 4) {
|
||||
const axes = inputs[3].getBigInt64Array();
|
||||
for (let i = 0; i < axes.length; i++) {
|
||||
updatePads[Number(axes[i])] = Number(bigInt64Pads[i]);
|
||||
updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]);
|
||||
}
|
||||
} else {
|
||||
bigInt64Pads.forEach((i, v) => updatePads[Number(i)] = (Number(v)));
|
||||
}
|
||||
|
||||
const pads: number[] = [];
|
||||
updatePads.forEach(v => pads.push(v));
|
||||
|
||||
return createAttributeWithCacheKey({mode: attributes.mode, value, pads});
|
||||
} else {
|
||||
return attributes;
|
||||
}
|
||||
};
|
||||
|
||||
const createPadProgramInfoLoader = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfoLoader => {
|
||||
const updatedAttributes = createPadAttributesFromInputs(inputs, attributes);
|
||||
const metadata:
|
||||
ProgramMetadata = {name: 'Pad', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
|
||||
return {...metadata, get: () => createPadProgramInfo(inputs, metadata, updatedAttributes)};
|
||||
};
|
||||
|
||||
export const pad = (context: ComputeContext, attributes: PadAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
context.compute(createPadProgramInfoLoader(context.inputs, attributes), {inputs: [0]});
|
||||
};
|
||||
|
||||
export const parsePadAttributes = (attributes: Record<string, unknown>): PadAttributes => {
|
||||
const mode = attributes.mode as number;
|
||||
const value = attributes.value as number;
|
||||
const pads = attributes.pads as number[];
|
||||
return createAttributeWithCacheKey({mode, value, pads});
|
||||
};
|
||||
|
|
@ -505,7 +505,7 @@
|
|||
// // "test_dynamicquantizelinear_min_adjusted_expanded",
|
||||
// // "test_dynamicquantizelinear_min_adjusted",
|
||||
// // "test_dynamicquantizelinear",
|
||||
// // "test_edge_pad",
|
||||
"test_edge_pad",
|
||||
// "test_einsum_batch_diagonal",
|
||||
// "test_einsum_batch_matmul",
|
||||
// "test_einsum_inner_prod",
|
||||
|
|
@ -965,7 +965,7 @@
|
|||
"test_reduce_sum_square_keepdims_random",
|
||||
"test_reduce_sum_square_negative_axes_keepdims_example",
|
||||
"test_reduce_sum_square_negative_axes_keepdims_random",
|
||||
// // "test_reflect_pad",
|
||||
"test_reflect_pad",
|
||||
"test_relu",
|
||||
// "test_reshape_allowzero_reordered",
|
||||
"test_reshape_extended_dims",
|
||||
|
|
@ -1308,7 +1308,8 @@
|
|||
"test_unsqueeze_three_axes",
|
||||
"test_unsqueeze_two_axes",
|
||||
"test_unsqueeze_unsorted_axes",
|
||||
"test_unsqueeze"
|
||||
"test_unsqueeze",
|
||||
"test_wrap_pad"
|
||||
// "test_upsample_nearest",
|
||||
// "test_where_example",
|
||||
// "test_where_long_example",
|
||||
|
|
@ -1361,8 +1362,8 @@
|
|||
"reduce-min.jsonc",
|
||||
"relu.jsonc",
|
||||
"gelu.jsonc",
|
||||
//"pad.jsonc",
|
||||
//"pad-big.jsonc",
|
||||
"pad.jsonc",
|
||||
"pad-big.jsonc",
|
||||
"pow.jsonc",
|
||||
"pow_int32.jsonc",
|
||||
"pow-big-number.jsonc",
|
||||
|
|
|
|||
|
|
@ -321,6 +321,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6
|
|||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -577,6 +583,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Pad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Pad)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
72
onnxruntime/core/providers/js/operators/pad.cc
Normal file
72
onnxruntime/core/providers/js/operators/pad.cc
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
#include "pad.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Pad,
|
||||
kOnnxDomain,
|
||||
2,
|
||||
10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Pad);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Pad,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
Pad);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Pad,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
17,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
Pad);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Pad,
|
||||
kOnnxDomain,
|
||||
18,
|
||||
18,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
Pad);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Pad,
|
||||
kOnnxDomain,
|
||||
19,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
Pad);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
34
onnxruntime/core/providers/js/operators/pad.h
Normal file
34
onnxruntime/core/providers/js/operators/pad.h
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/cpu/tensor/padbase.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Pad : public JsKernel, public PadBase {
|
||||
public:
|
||||
explicit Pad(const OpKernelInfo& info) : JsKernel(info), PadBase(info) {
|
||||
std::vector<int32_t> pads;
|
||||
if (!is_dynamic_) {
|
||||
pads.resize(pads_.size());
|
||||
for (size_t i = 0; i < pads_.size(); ++i) {
|
||||
pads[i] = gsl::narrow_cast<int32_t>(pads_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1,
|
||||
"value" : $2,
|
||||
"pads" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}),
|
||||
static_cast<int32_t>(mode_),
|
||||
static_cast<double>(value_),
|
||||
gsl::narrow_cast<int32_t>(pads.size()),
|
||||
reinterpret_cast<int32_t>((pads.size() > 0) ? pads.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue