From ea55700e1cb3e4f8485e9c9cfd69ebe470701397 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Fri, 4 Aug 2023 01:09:37 +0400 Subject: [PATCH] [js/web] JSEP Gather OP (#16855) ### Description Added a Gather op that works with both i32 and i64 indices, assuming that the index values fall within the i32 range. The assumption is safe because it is not possible to allocate an input buffer larger than 2 GB. All data from the input tensor is treated as u32 words, copying one u32 per element, or two for i64, u64 and double. --------- Co-authored-by: Guenther Schmuelling --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 107 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 7 +- .../providers/js/js_execution_provider.cc | 8 ++ .../core/providers/js/operators/gather.cc | 53 +++++++++ .../core/providers/js/operators/gather.h | 24 ++++ 7 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/gather.ts create mode 100644 onnxruntime/core/providers/js/operators/gather.cc create mode 100644 onnxruntime/core/providers/js/operators/gather.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 7b6d72bc78..a0ff4a3aae 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -35,6 +35,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| Gather | ai.onnx(1-10,11-12,13+) | | | Gelu | com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 4fa468cde4..23b47033e5 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -7,6 +7,7 @@ import {concat, parseConcatAttributes} from './ops/concat'; import {conv, 
parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; import {expand} from './ops/expand'; +import {gather, parseGatherAttributes} from './ops/gather'; import {gelu} from './ops/gelu'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; @@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['Gather', [gather, parseGatherAttributes]], ['Gelu', [gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts new file mode 100644 index 0000000000..113bf7c7cc --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; + +export interface GatherAttributes extends AttributeWithCacheKey { + axis: number; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Gather requires 2 inputs.'); + } +}; + +const createGatherProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + + const inputRank = inputShape.length; + const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank); + + const outputShape = inputShape.slice(0); + outputShape.splice(axis, 1, ...indicesShape); + + const inputDataType = inputs[0].dataType; + const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); + const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; + const indicesElementSize = inputs[1].dataType === DataType.int64 ? 
2 : 1; + const blockSize = elementSize * block; + const M = ShapeUtil.sizeToDimension(inputShape, axis); + const N = ShapeUtil.size(indicesShape); + const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; + const gatheredBatchElements = N * block * elementSize; + const axisDimLimit = inputShape[axis]; + + const inputSize = ShapeUtil.size(inputShape) * elementSize; + const outputSize = ShapeUtil.size(outputShape) * elementSize; + + const totalGathers = M * N; + // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits + // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor + // Input data will be treated as u32 or two u32 for 8-byte tensors + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const N: u32 = ${N}; + const elementSize: u32 = ${elementSize}; + const indicesElementSize: u32 = ${indicesElementSize}; + + @group(0) @binding(0) var input : array; + @group(0) @binding(1) var inputIndices : array; + @group(0) @binding(2) var output: array; + + ${shaderHelper.mainStart()} + let batch: u32 = global_idx / N; + let i: u32 = global_idx % N; + + let srcOffsetBatch: u32 = batch * ${dataBatchElements}; + let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; + var idx = inputIndices[i * indicesElementSize]; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + + let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; + let dstOffset = dstOffsetBatch + i * ${blockSize}; + if (srcOffset >= ${inputSize}) { + return; + } + if (dstOffset >= ${outputSize}) { + return; + } + for (var j: u32 = 0; j < ${blockSize}; j++) { + output[dstOffset + j] = input[srcOffset + j]; + } + }`; + return { + ...metadata, + outputs: [ + {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) + }; + }; + +export const parseGatherAttributes = 
(attributes: Record): GatherAttributes => + createAttributeWithCacheKey({axis: attributes.axis as number}); + +export const gather = (context: ComputeContext, attributes: GatherAttributes): void => { + const inputs = context.inputs; + validateInputs(inputs); + + const metadata = { + name: 'Gather', + inputTypes: [GpuDataType.default, GpuDataType.default], + cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10), + }; + + context.compute(createGatherProgramInfo(metadata, context.inputs, attributes)); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 00ac7acfc9..c253aeff30 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -535,9 +535,10 @@ "test_flatten_negative_axis4", "test_floor_example", "test_floor", - // "test_gather_0", - // "test_gather_1", - // "test_gather_2d_indices", + "test_gather_0", + "test_gather_1", + "test_gather_2d_indices", + "test_gather_negative_indices", // "test_gather_elements_0", // "test_gather_elements_1", // "test_gather_elements_negative_indices", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index dba68137c7..677a254301 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -266,6 +266,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather); + class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); @@ -477,6 +481,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc new file mode 100644 index 0000000000..ec1ae71243 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "gather.h" + +namespace onnxruntime { +namespace js { + +using AllSupportedSize = + TypeList< + float, + double, + int64_t, + uint64_t, + int32_t, + uint32_t>; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Gather, + kOnnxDomain, + 1, + 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Gather, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +ONNX_OPERATOR_KERNEL_EX( + Gather, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), + Gather); + +} // namespace js +} // namespace onnxruntime 
diff --git a/onnxruntime/core/providers/js/operators/gather.h b/onnxruntime/core/providers/js/operators/gather.h new file mode 100644 index 0000000000..72603d461c --- /dev/null +++ b/onnxruntime/core/providers/js/operators/gather.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class Gather : public JsKernel { + public: + Gather(const OpKernelInfo& info) : JsKernel(info) { + int64_t axis = info.GetAttrOrDefault("axis", 0); + + JSEP_INIT_KERNEL_ATTRIBUTE(Gather, ({ + "axis" : Number($1), + }), + static_cast(axis)); + } +}; + +} // namespace js +} // namespace onnxruntime