[JS/WebGPU] Support GatherElements kernel (#17243)

### Description
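As the title says: adds a WebGPU (JSEP) implementation of the ONNX `GatherElements` operator (opsets 11-12 and 13+), including the kernel registration and operator tests.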


### Motivation and Context
Improve WebGPU kernel coverage
This commit is contained in:
Hariharan Seshadri 2023-08-28 09:55:25 -07:00 committed by GitHub
parent 53169f59e5
commit cbd97515cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 418 additions and 3 deletions

View file

@ -38,6 +38,7 @@ Do not modify directly.*
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |

View file

@ -8,6 +8,7 @@ import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
@ -58,6 +59,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],

View file

@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import {inputVariable, outputVariable, ShaderHelper} from './common';
/**
 * Attributes of the GatherElements operator, extended with the cache key
 * contributed by `AttributeWithCacheKey`.
 */
export interface GatherElementsAttributes extends AttributeWithCacheKey {
// Axis to gather along. It is normalized against the rank of the data input
// before use, so negative values (counting from the last dimension) are accepted.
axis: number;
}
/**
 * Validates the inputs of a GatherElements node.
 *
 * Expects exactly two tensors: `inputs[0]` is the data tensor and `inputs[1]`
 * is the indices tensor. The data tensor must have rank >= 1 and both tensors
 * must have the same rank.
 *
 * @param inputs the input tensor views handed to the kernel.
 * @throws Error if the input count or the ranks do not satisfy the rules above.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('GatherElements requires 2 inputs.');
  }

  if (inputs[0].dims.length < 1) {
    throw new Error('GatherElements requires that the data input be rank >= 1.');
  }

  if (inputs[0].dims.length !== inputs[1].dims.length) {
    // Kept on a single line on purpose: a multi-line template literal would embed the
    // line break (and any source indentation) into the message shown to the user.
    throw new Error('GatherElements requires that the data input and indices input tensors be of same rank.');
  }
};
/**
 * Builds the WebGPU program description for GatherElements.
 *
 * The output has the shape of the indices tensor and the data type of the data
 * tensor. Each shader invocation handles one output element: it reads the
 * index stored at the same flat offset, wraps a negative index by adding the
 * size of the gather axis, and copies the data element whose coordinate along
 * `axis` is that index while all other coordinates are those of the output
 * element.
 *
 * @param metadata program metadata that is spread into the returned program info.
 * @param inputs `inputs[0]` is the data tensor, `inputs[1]` the indices tensor.
 * @param attributes operator attributes; `axis` is normalized against the data rank.
 * @returns the program info (output description, shader source generator, dispatch size).
 */
const createGatherElementsProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => {
      const inputShape = inputs[0].dims;
      const inputOutputDataType = inputs[0].dataType;
      const inputRank = inputShape.length;
      const inputStrides = ShapeUtil.computeStrides(inputShape);
      const inputSize = ShapeUtil.size(inputShape);

      const indicesShape = inputs[1].dims;
      const indicesDataType = inputs[1].dataType;
      const indicesSize = ShapeUtil.size(indicesShape);

      const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
      const axisDimLimit = inputShape[axis];

      // One output element per index element.
      const outputShape = indicesShape.slice(0);
      const outputSize = ShapeUtil.size(outputShape);

      const input = inputVariable('input', inputOutputDataType, inputShape);
      // The indices are only ever read by flat offset, so they are declared as a 1-D variable.
      const indices = inputVariable('indices', indicesDataType, [indicesSize]);
      const output = outputVariable('output', inputOutputDataType, outputShape);

      // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
      // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
      // NOTE(review): the original comment here also said "Input data will be treated as u32 or two u32
      // for 8-byte tensors"; the input above is declared with its own data type - confirm whether that
      // note still applies to this kernel.
      const getShaderSource = (shaderHelper: ShaderHelper) => `
      const inputStrides = array<u32, ${inputStrides.length}>(${inputStrides.map(i => `${i}u`).join(',')});
      ${shaderHelper.declareVariables(input, indices, output)}
      ${shaderHelper.mainStart()}
      ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}

      let outputIndices = ${output.offsetToIndices('global_idx')};

      var idx = ${indices.getByOffset('global_idx')};
      if (idx < 0) {
        idx = idx + ${axisDimLimit};
      }

      // Should never hit this with valid values in indices
      // This is a guard against malicious data in the indices input. The index itself is checked
      // (rather than only the final offset) because an out-of-range index converted to u32 and
      // multiplied by a stride can wrap around and land inside the input buffer again.
      if (idx < 0 || idx >= ${axisDimLimit}) {
        return;
      }

      var srcOffset = u32(0);

      for (var i = 0; i < ${inputRank}; i++) {
        if (i == ${axis}) {
          srcOffset += u32(idx) * inputStrides[i];
        } else {
          srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
        }
      }

      // srcOffset is unsigned, so only the upper bound needs checking.
      if (srcOffset >= ${inputSize}u) {
        return;
      }

      output[global_idx] = input[srcOffset];
  }`;

      return {
        ...metadata,
        outputs: [{dims: outputShape, dataType: inputOutputDataType, gpuDataType: GpuDataType.default}],
        getShaderSource,
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };
/**
 * Converts the raw attribute record of a GatherElements node into
 * `GatherElementsAttributes`.
 *
 * @param attributes raw attributes; `axis` is read and treated as a number.
 * @returns the result of `createAttributeWithCacheKey` for `{axis}`.
 */
export const parseGatherElementsAttributes = (attributes: Record<string, unknown>): GatherElementsAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({axis});
};
/**
 * Kernel entry point for GatherElements.
 *
 * Validates the inputs on the compute context, builds the program info for
 * them and hands it to `context.compute`.
 *
 * @param context compute context providing the input tensors and `compute`.
 * @param attributes parsed operator attributes; `cacheKey` is used as the cache hint.
 */
export const gatherElements = (context: ComputeContext, attributes: GatherElementsAttributes): void => {
  validateInputs(context.inputs);

  // Both inputs (data and indices) use the default GPU data type.
  const programMetadata = {
    name: 'GatherElements',
    inputTypes: [GpuDataType.default, GpuDataType.default],
    cacheHint: attributes.cacheKey,
  };

  const programInfo = createGatherElementsProgramInfo(programMetadata, context.inputs, attributes);
  context.compute(programInfo);
};

View file

@ -0,0 +1,234 @@
[
{
"name": "GatherElements float32 data + int32 indices-1",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "float32 data + int32 indices-1",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [0, 0, 1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 1, 4, 3],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "GatherElements float32 data + int32 indices-2",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "float32 data + int32 indices-2",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [0, 1, 1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 2, 4, 3],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "GatherElements float32 data + int64 indices - 1",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "float32 data + int64 indices - 1",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [0, 0, -1, 0],
"dims": [2, 2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 1, 4, 3],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "GatherElements float32 data + int64 indices - 2",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "float32 data + int64 indices - 2",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [0, 0, -2, 0],
"dims": [2, 2],
"type": "int64"
}
],
"outputs": [
{
"data": [1, 1, 3, 3],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "GatherElements int32 data + int32 indices-1",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "int32 data + int32 indices-1",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "int32"
},
{
"data": [0, 0, 1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 1, 4, 3],
"dims": [2, 2],
"type": "int32"
}
]
}
]
},
{
"name": "GatherElements uint32 data + int32 indices-1",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "uint32 data + int32 indices-1",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "uint32"
},
{
"data": [0, 0, 1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 1, 4, 3],
"dims": [2, 2],
"type": "uint32"
}
]
}
]
},
{
"name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": -1, "type": "int" }],
"cases": [
{
"name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices",
"inputs": [
{
"data": [1, 2, 3, 4],
"dims": [2, 2],
"type": "float32"
},
{
"data": [0, 0, -1, 0],
"dims": [2, 2],
"type": "int32"
}
],
"outputs": [
{
"data": [1, 1, 4, 3],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "GatherElements float32 data + int32 indices-3",
"operator": "GatherElements",
"attributes": [{ "name": "axis", "data": 0, "type": "int" }],
"cases": [
{
"name": "GatherElements float32 data + int32 indices-3",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [3, 3],
"type": "float32"
},
{
"data": [1, 2, 0, 2, 0, 0],
"dims": [2, 3],
"type": "int32"
}
],
"outputs": [
{
"data": [4, 8, 3, 7, 2, 3],
"dims": [2, 3],
"type": "float32"
}
]
}
]
}
]

View file

@ -539,9 +539,9 @@
"test_gather_1",
"test_gather_2d_indices",
"test_gather_negative_indices",
// "test_gather_elements_0",
// "test_gather_elements_1",
// "test_gather_elements_negative_indices",
"test_gather_elements_0",
"test_gather_elements_1",
"test_gather_elements_negative_indices",
// "test_gather_negative_indices",
// // "test_gathernd_example_float32",
// // "test_gathernd_example_int32_batch_dim1",
@ -1339,6 +1339,7 @@
"exp.jsonc",
"expand.jsonc",
"floor.jsonc",
"gather-elements.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",
"greater.jsonc",

View file

@ -291,6 +291,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
@ -532,6 +535,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,

View file

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/js/js_kernel.h"
#include "core/providers/js/js_data_types.h"
#include "gather_elements.h"
namespace onnxruntime {
namespace js {
// Registers the GatherElements kernel for the JS execution provider in the ONNX domain,
// covering opset versions 11 through 12.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
GatherElements,
kOnnxDomain,
11,
12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
// "T" (data/output element type): float, int32_t or uint32_t.
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
// "Tind" (indices element type): int32_t or int64_t.
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
GatherElements);
// Registers the GatherElements kernel for the JS execution provider in the ONNX domain,
// starting at opset version 13 (non-versioned registration, so no upper bound is given here).
ONNX_OPERATOR_KERNEL_EX(
GatherElements,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
// "T" (data/output element type): float, int32_t or uint32_t.
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()})
// "Tind" (indices element type): int32_t or int64_t.
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
GatherElements);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
// JS execution provider kernel for the GatherElements operator.
// The class only captures the operator attributes at construction time; it defines
// no compute logic of its own here (presumably that is provided by JsKernel and the
// JavaScript-side implementation - confirm against js_kernel.h).
class GatherElements : public JsKernel {
public:
GatherElements(const OpKernelInfo& info) : JsKernel(info) {
// "axis" is optional on the node; fall back to 0 when it is absent.
int64_t axis = info.GetAttrOrDefault<int64_t>("axis", 0);
// Passes the attribute object {"axis": <number>} through JSEP_INIT_KERNEL_ATTRIBUTE.
// $1 refers to the trailing macro argument, i.e. the axis narrowed to int32_t.
JSEP_INIT_KERNEL_ATTRIBUTE(GatherElements, ({
"axis" : Number($1),
}),
static_cast<int32_t>(axis));
}
};
} // namespace js
} // namespace onnxruntime