From cbd97515cd6566f1cd369d49240e5331c9028775 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 28 Aug 2023 09:55:25 -0700 Subject: [PATCH] [JS/WebGPU] Support GatherElements kernel (#17243) ### Description As title ### Motivation and Context Improve WebGPU kernel coverage --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../wasm/jsep/webgpu/ops/gather-elements.ts | 110 ++++++++ js/web/test/data/ops/gather-elements.jsonc | 234 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 7 +- .../providers/js/js_execution_provider.cc | 6 + .../providers/js/operators/gather_elements.cc | 37 +++ .../providers/js/operators/gather_elements.h | 24 ++ 8 files changed, 418 insertions(+), 3 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts create mode 100644 js/web/test/data/ops/gather-elements.jsonc create mode 100644 onnxruntime/core/providers/js/operators/gather_elements.cc create mode 100644 onnxruntime/core/providers/js/operators/gather_elements.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index c56bf4c6ff..a969e1b86b 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -38,6 +38,7 @@ Do not modify directly.* | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | Gather | ai.onnx(1-10,11-12,13+) | | +| GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ae4b754f76..23aabb6531 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -8,6 +8,7 @@ import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; import {expand} from './ops/expand'; import {gather, parseGatherAttributes} from './ops/gather'; +import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; @@ -58,6 +59,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Expand', [expand]], ['Floor', [unaryOps.floor]], ['Gather', [gather, parseGatherAttributes]], + ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts new file mode 100644 index 0000000000..57c5fccfd8 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface GatherElementsAttributes extends AttributeWithCacheKey { + axis: number; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('GatherElements requires 2 inputs.'); + } + + if (inputs[0].dims.length < 1) { + throw new Error('GatherElements requires that the data input be rank >= 1.'); + } + + if (inputs[0].dims.length !== inputs[1].dims.length) { + throw new Error(`GatherElements requires that the data input and + indices input tensors be of same rank.`); + } +}; + +const createGatherElementsProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherElementsAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const inputOutputDataType = inputs[0].dataType; + const inputRank = inputShape.length; + const inputStrides = ShapeUtil.computeStrides(inputShape); + const inputSize = ShapeUtil.size(inputShape); + + const indicesShape = inputs[1].dims; + const indicesDataType = inputs[1].dataType; + const indicesSize = ShapeUtil.size(indicesShape); + + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); + const axisDimLimit = inputShape[axis]; + + const outputShape = indicesShape.slice(0); + const outputSize = ShapeUtil.size(outputShape); + + const input = inputVariable('input', inputOutputDataType, inputShape); + const indices = inputVariable('indices', indicesDataType, [indicesSize]); + const output = outputVariable('output', inputOutputDataType, outputShape); + + + // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits + // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor + // Input data will be treated as u32 or two u32 for 8-byte tensors + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const inputStrides = array(${inputStrides.map(i => `${i}u`).join(',')}); + ${shaderHelper.declareVariables(input, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + let outputIndices = ${output.offsetToIndices('global_idx')}; + + var idx = ${indices.getByOffset('global_idx')}; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + + var srcOffset = u32(0); + + for (var i = 0; i < ${inputShape.length}; i++) { + if (i == ${axis}) { + srcOffset += u32(idx) * inputStrides[i]; + } else { + srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i]; + } + } + + // Should never hit this with valid values in indices + // This is a guard against malicious data in the indices input + if (srcOffset < 0 || srcOffset >= ${inputSize}) { + return; + } + + output[global_idx] = input[srcOffset]; + }`; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const parseGatherElementsAttributes = (attributes: Record): GatherElementsAttributes => + createAttributeWithCacheKey({axis: attributes.axis as number}); + +export const gatherElements = (context: ComputeContext, attributes: GatherElementsAttributes): void => { + const inputs = context.inputs; + validateInputs(inputs); + + const metadata = { + name: 'GatherElements', + inputTypes: [GpuDataType.default, GpuDataType.default], + cacheHint: attributes.cacheKey, + }; + + context.compute(createGatherElementsProgramInfo(metadata, context.inputs, attributes)); +}; diff --git a/js/web/test/data/ops/gather-elements.jsonc b/js/web/test/data/ops/gather-elements.jsonc new file mode 100644 index 0000000000..caab3c11f6 --- /dev/null +++ b/js/web/test/data/ops/gather-elements.jsonc @@ -0,0 +1,234 @@ +[ + { + "name": "GatherElements float32 data + int32 indices-1", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "float32 data + int32 indices-1", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0, 0, 1, 0], + "dims": [2, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 1, 4, 3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherElements float32 data + int32 indices-2", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "float32 data + int32 indices-2", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0, 1, 1, 0], + "dims": [2, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 4, 3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherElements float32 data + int64 indices - 1", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "float32 data + int64 indices - 1", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0, 0, -1, 0], + "dims": [2, 2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 4, 3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherElements float32 data + int64 indices - 2", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "float32 data + int64 indices - 2", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0, 0, -2, 0], + "dims": [2, 2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 1, 3, 3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherElements int32 data + int32 indices-1", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "int32 data + int32 indices-1", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "int32" + }, + { + "data": [0, 0, 1, 0], + "dims": [2, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 1, 4, 3], + "dims": [2, 2], + "type": "int32" + } + ] + } + ] + }, + { + "name": "GatherElements uint32 data + int32 indices-1", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "uint32 data + int32 indices-1", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "uint32" + }, + { + "data": [0, 0, 1, 0], + "dims": [2, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 1, 4, 3], + "dims": [2, 2], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": -1, "type": "int" }], + "cases": [ + { + "name": "GatherElements float32 data + int32 indices-1 + Negative axis + Negative indices", + "inputs": [ + { + "data": [1, 2, 3, 4], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [0, 0, -1, 0], + "dims": [2, 2], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 1, 4, 3], + "dims": [2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherElements float32 data + int32 indices-3", + "operator": "GatherElements", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "GatherElements float32 data + int32 indices-3", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 2, 0, 2, 0, 0], + "dims": [2, 3], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 8, 3, 7, 2, 3], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index e0b0207c9f..31505d95b9 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -539,9 +539,9 @@ "test_gather_1", "test_gather_2d_indices", "test_gather_negative_indices", - // "test_gather_elements_0", - // "test_gather_elements_1", - // "test_gather_elements_negative_indices", + "test_gather_elements_0", + "test_gather_elements_1", + "test_gather_elements_negative_indices", // "test_gather_negative_indices", // // "test_gathernd_example_float32", // // "test_gathernd_example_int32_batch_dim1", @@ -1339,6 +1339,7 @@ "exp.jsonc", "expand.jsonc", "floor.jsonc", + "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", "greater.jsonc", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 2732eb0c3d..829f3e5f4f 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -291,6 +291,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); @@ -532,6 +535,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gather_elements.cc b/onnxruntime/core/providers/js/operators/gather_elements.cc new file mode 100644 index 0000000000..b4db122341 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_elements.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "gather_elements.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherElements, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + GatherElements); + +ONNX_OPERATOR_KERNEL_EX( + GatherElements, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + GatherElements); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gather_elements.h b/onnxruntime/core/providers/js/operators/gather_elements.h new file mode 100644 index 0000000000..ce90145133 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather_elements.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class GatherElements : public JsKernel { + public: + GatherElements(const OpKernelInfo& info) : JsKernel(info) { + int64_t axis = info.GetAttrOrDefault("axis", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(GatherElements, ({ + "axis" : Number($1), + }), + static_cast(axis)); + } +}; + +} // namespace js +} // namespace onnxruntime