diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 7b6d72bc78..a0ff4a3aae 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -35,6 +35,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| Gather | ai.onnx(1-10,11-12,13+) | | | Gelu | com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 4fa468cde4..23b47033e5 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -7,6 +7,7 @@ import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; import {expand} from './ops/expand'; +import {gather, parseGatherAttributes} from './ops/gather'; import {gelu} from './ops/gelu'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; @@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['Gather', [gather, parseGatherAttributes]], ['Gelu', [gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts new file mode 100644 index 0000000000..113bf7c7cc --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; + +export interface GatherAttributes extends AttributeWithCacheKey { + axis: number; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Gather requires 2 inputs.'); + } +}; + +const createGatherProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + + const inputRank = inputShape.length; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); + + const outputShape = inputShape.slice(0); + outputShape.splice(axis, 1, ...indicesShape); + + const inputDataType = inputs[0].dataType; + const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); + const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; + const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; + const blockSize = elementSize * block; + const M = ShapeUtil.sizeToDimension(inputShape, axis); + const N = ShapeUtil.size(indicesShape); + const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; + const gatheredBatchElements = N * block * elementSize; + const axisDimLimit = inputShape[axis]; + + const inputSize = ShapeUtil.size(inputShape) * elementSize; + const outputSize = ShapeUtil.size(outputShape) * elementSize; + + const totalGathers = M * N; + // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits + // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor + // Input data will be treated as u32 or two u32 for 8-byte tensors + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const N: u32 = ${N}; + const elementSize: u32 = ${elementSize}; + const indicesElementSize: u32 = ${indicesElementSize}; + + @group(0) @binding(0) var input : array; + @group(0) @binding(1) var inputIndices : array; + @group(0) @binding(2) var output: array; + + ${shaderHelper.mainStart()} + let batch: u32 = global_idx / N; + let i: u32 = global_idx % N; + + let srcOffsetBatch: u32 = batch * ${dataBatchElements}; + let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; + var idx = inputIndices[i * indicesElementSize]; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + + let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; + let dstOffset = dstOffsetBatch + i * ${blockSize}; + if (srcOffset >= ${inputSize}) { + return; + } + if (dstOffset >= ${outputSize}) { + return; + } + for (var j: u32 = 0; j < ${blockSize}; j++) { + output[dstOffset + j] = input[srcOffset + j]; + } + }`; + return { + ...metadata, + outputs: [ + {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) + }; + }; + +export const parseGatherAttributes = (attributes: Record): GatherAttributes => + createAttributeWithCacheKey({axis: attributes.axis as number}); + +export const gather = (context: ComputeContext, attributes: GatherAttributes): void => { + const inputs = context.inputs; + validateInputs(inputs); + + const metadata = { + name: 'Gather', + inputTypes: [GpuDataType.default, GpuDataType.default], + cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10), + }; + + context.compute(createGatherProgramInfo(metadata, context.inputs, attributes)); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 00ac7acfc9..c253aeff30 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -535,9 +535,10 @@ "test_flatten_negative_axis4", "test_floor_example", "test_floor", - // "test_gather_0", - // "test_gather_1", - // "test_gather_2d_indices", + "test_gather_0", + "test_gather_1", + "test_gather_2d_indices", + "test_gather_negative_indices", // "test_gather_elements_0", // "test_gather_elements_1", // "test_gather_elements_negative_indices", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index dba68137c7..677a254301 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -266,6 +266,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); @@ -477,6 +481,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc new file mode 100644 index 0000000000..ec1ae71243 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "gather.h" + +namespace onnxruntime { +namespace js { + +using AllSupportedSize = + TypeList< + float, + double, + int64_t, + uint64_t, + int32_t, + uint32_t>; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Gather, + kOnnxDomain, + 1, + 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Gather, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +ONNX_OPERATOR_KERNEL_EX( + Gather, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/gather.h b/onnxruntime/core/providers/js/operators/gather.h new file mode 100644 index 0000000000..72603d461c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class Gather : public JsKernel { + public: + Gather(const OpKernelInfo& info) : JsKernel(info) { + int64_t axis = info.GetAttrOrDefault("axis", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(Gather, ({ + "axis" : Number($1), + }), + static_cast(axis)); + } +}; + +} // namespace js +} // namespace onnxruntime