mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[JS/WebGPU] Support GatherElements kernel (#17243)
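### Description Add a WebGPU (JSEP) kernel for the ONNX GatherElements operator, together with its kernel registration and operator tests. ### Motivation and Context Improve WebGPU kernel coverage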
This commit is contained in:
parent
53169f59e5
commit
cbd97515cd
8 changed files with 418 additions and 3 deletions
|
|
@ -38,6 +38,7 @@ Do not modify directly.*
|
|||
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| Gather | ai.onnx(1-10,11-12,13+) | |
|
||||
| GatherElements | ai.onnx(11-12,13+) | |
|
||||
| Gelu | com.microsoft(1+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import {conv, parseConvAttributes} from './ops/conv';
|
|||
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
|
||||
import {expand} from './ops/expand';
|
||||
import {gather, parseGatherAttributes} from './ops/gather';
|
||||
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
|
||||
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
|
||||
|
|
@ -58,6 +59,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Expand', [expand]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['Gather', [gather, parseGatherAttributes]],
|
||||
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
|
||||
['Gelu', [unaryOps.gelu]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
|
|
|
|||
110
js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
Normal file
110
js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
/**
 * Attributes accepted by the GatherElements operator.
 *
 * Extends `AttributeWithCacheKey`, so instances also carry the cache key that is
 * forwarded as the program's cache hint.
 */
export interface GatherElementsAttributes extends AttributeWithCacheKey {
  /**
   * Axis along which elements are gathered. It is passed through
   * `ShapeUtil.normalizeAxis` against the data rank before use, so the raw
   * attribute value (including a negative one) is stored here as-is.
   */
  axis: number;
}
|
||||
|
||||
/**
 * Validates the operands handed to GatherElements before any program is built.
 *
 * Expects exactly two tensors: `inputs[0]` is the data tensor and `inputs[1]` is
 * the indices tensor.
 *
 * @param inputs - the operator inputs, in the order [data, indices].
 * @throws Error when the input count is not 2, when the data tensor has rank 0,
 *     or when data and indices do not have the same rank.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('GatherElements requires 2 inputs.');
  }

  if (inputs[0].dims.length < 1) {
    throw new Error('GatherElements requires that the data input be rank >= 1.');
  }

  if (inputs[0].dims.length !== inputs[1].dims.length) {
    // Plain single-line literal so the message carries no embedded line break or source indentation.
    throw new Error(
        'GatherElements requires that the data input and indices input tensors be of same rank.');
  }
};
|
||||
|
||||
/**
 * Builds the WebGPU program description for GatherElements.
 *
 * `inputs[0]` is the data tensor and `inputs[1]` is the indices tensor. The output
 * takes the shape of the indices tensor and the data type of the data tensor. For
 * every output element the shader reads one index, replaces the coordinate on
 * `axis` with it, and copies the data element found at that coordinate.
 *
 * @param metadata - program name / input GPU types / cache hint, spread into the result.
 * @param inputs - operator inputs in the order [data, indices]; assumed to be validated already.
 * @param attributes - operator attributes; `axis` is normalized against the data rank here.
 * @returns the program info: a single output descriptor, the shader source generator, and a
 *     dispatch size of ceil(outputSize / 64) workgroups.
 */
const createGatherElementsProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => {
      const inputShape = inputs[0].dims;
      const inputOutputDataType = inputs[0].dataType;
      const inputRank = inputShape.length;
      const inputStrides = ShapeUtil.computeStrides(inputShape);
      const inputSize = ShapeUtil.size(inputShape);

      const indicesShape = inputs[1].dims;
      const indicesDataType = inputs[1].dataType;
      const indicesSize = ShapeUtil.size(indicesShape);

      const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
      const axisDimLimit = inputShape[axis];

      // The output has exactly the shape of the indices tensor.
      const outputShape = indicesShape.slice(0);
      const outputSize = ShapeUtil.size(outputShape);

      const input = inputVariable('input', inputOutputDataType, inputShape);
      // Indices are addressed linearly by the global invocation index, hence the flat shape.
      const indices = inputVariable('indices', indicesDataType, [indicesSize]);
      const output = outputVariable('output', inputOutputDataType, outputShape);

      // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
      // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
      // NOTE(review): how 8-byte data tensors are laid out depends on inputVariable/outputVariable;
      // this shader copies exactly one element per invocation - confirm before enabling such types.
      const getShaderSource = (shaderHelper: ShaderHelper) => `
      const inputStrides = array<u32, ${inputStrides.length}>(${inputStrides.map(i => `${i}u`).join(',')});
      ${shaderHelper.declareVariables(input, indices, output)}
      ${shaderHelper.mainStart()}
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}

      let outputIndices = ${output.offsetToIndices('global_idx')};

      var idx = ${indices.getByOffset('global_idx')};
      if (idx < 0) {
        idx = idx + ${axisDimLimit};
      }

      // Should never hit this with valid values in indices
      // This is a guard against malicious data in the indices input.
      // The check is done on the signed index: once converted to u32 a negative
      // value wraps around and the resulting offset may fall back inside the buffer.
      if (idx < 0 || idx >= ${axisDimLimit}) {
        return;
      }

      var srcOffset = u32(0);

      for (var i = 0; i < ${inputShape.length}; i++) {
        if (i == ${axis}) {
          srcOffset += u32(idx) * inputStrides[i];
        } else {
          srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
        }
      }

      // Coordinates on the other axes come from the output (indices) shape, so the
      // final offset is still checked against the size of the data tensor.
      if (srcOffset >= ${inputSize}) {
        return;
      }

      output[global_idx] = input[srcOffset];
  }`;

      return {
        ...metadata,
        outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
        getShaderSource,
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };
|
||||
|
||||
/**
 * Converts the raw attribute bag of a GatherElements node into typed attributes.
 *
 * @param attributes - raw attributes; only `axis` is read and it is taken to be a number.
 * @returns the attributes object produced by `createAttributeWithCacheKey`.
 */
export const parseGatherElementsAttributes = (attributes: Record<string, unknown>): GatherElementsAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({axis});
};
|
||||
|
||||
/**
 * Entry point of the GatherElements kernel: validates the inputs, then builds and
 * runs the WebGPU program on the given compute context.
 *
 * @param context - compute context supplying the [data, indices] inputs and `compute`.
 * @param attributes - parsed operator attributes; their cache key becomes the cache hint.
 * @throws Error when the inputs fail validation.
 */
export const gatherElements = (context: ComputeContext, attributes: GatherElementsAttributes): void => {
  validateInputs(context.inputs);

  // NOTE(review): the generated shader bakes in the input shapes, while the cache hint
  // below only covers the attributes - confirm the program cache key also includes input dims.
  const programMetadata = {
    name: 'GatherElements',
    inputTypes: [GpuDataType.default, GpuDataType.default],
    cacheHint: attributes.cacheKey,
  };

  const programInfo = createGatherElementsProgramInfo(programMetadata, context.inputs, attributes);
  context.compute(programInfo);
};
|
||||
234
js/web/test/data/ops/gather-elements.jsonc
Normal file
234
js/web/test/data/ops/gather-elements.jsonc
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
[
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-1",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "float32 data + int32 indices-1",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, 1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-2",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "float32 data + int32 indices-2",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 1, 1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements float32 data + int64 indices - 1",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "float32 data + int64 indices - 1",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, -1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements float32 data + int64 indices - 2",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "float32 data + int64 indices - 2",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, -2, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 3, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements int32 data + int32 indices-1",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "int32 data + int32 indices-1",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, 1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements uint32 data + int32 indices-1",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "uint32 data + int32 indices-1",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "uint32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, 1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "uint32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": -1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 0, -1, 0],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 4, 3],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-3",
|
||||
"operator": "GatherElements",
|
||||
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "GatherElements float32 data + int32 indices-3",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 0, 2, 0, 0],
|
||||
"dims": [2, 3],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [4, 8, 3, 7, 2, 3],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -539,9 +539,9 @@
|
|||
"test_gather_1",
|
||||
"test_gather_2d_indices",
|
||||
"test_gather_negative_indices",
|
||||
// "test_gather_elements_0",
|
||||
// "test_gather_elements_1",
|
||||
// "test_gather_elements_negative_indices",
|
||||
"test_gather_elements_0",
|
||||
"test_gather_elements_1",
|
||||
"test_gather_elements_negative_indices",
|
||||
// "test_gather_negative_indices",
|
||||
// // "test_gathernd_example_float32",
|
||||
// // "test_gathernd_example_int32_batch_dim1",
|
||||
|
|
@ -1339,6 +1339,7 @@
|
|||
"exp.jsonc",
|
||||
"expand.jsonc",
|
||||
"floor.jsonc",
|
||||
"gather-elements.jsonc",
|
||||
"gemm.jsonc",
|
||||
"global-average-pool.jsonc",
|
||||
"greater.jsonc",
|
||||
|
|
|
|||
|
|
@ -291,6 +291,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
|
||||
|
|
@ -532,6 +535,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
|
||||
|
|
|
|||
37
onnxruntime/core/providers/js/operators/gather_elements.cc
Normal file
37
onnxruntime/core/providers/js/operators/gather_elements.cc
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
#include "gather_elements.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
// GatherElements for opset versions 11 through 12 on the JS execution provider.
// "T" (data) accepts float, int32 and uint32 tensors; "Tind" (indices) accepts int32 and int64.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    GatherElements,
    kOnnxDomain,
    11,
    12,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<TypeList<float, int32_t, uint32_t>>())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
    GatherElements);
|
||||
|
||||
// GatherElements for opset version 13 and later on the JS execution provider.
// "T" (data) accepts float, int32 and uint32 tensors; "Tind" (indices) accepts int32 and int64.
ONNX_OPERATOR_KERNEL_EX(
    GatherElements,
    kOnnxDomain,
    13,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<TypeList<float, int32_t, uint32_t>>())
        .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
    GatherElements);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/core/providers/js/operators/gather_elements.h
Normal file
24
onnxruntime/core/providers/js/operators/gather_elements.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class GatherElements : public JsKernel {
|
||||
public:
|
||||
GatherElements(const OpKernelInfo& info) : JsKernel(info) {
|
||||
int64_t axis = info.GetAttrOrDefault<int64_t>("axis", 0);
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(GatherElements, ({
|
||||
"axis" : Number($1),
|
||||
}),
|
||||
static_cast<int32_t>(axis));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue