[js/web] JSEP Gather OP (#16855)

### Description
Added a Gather op that works with both i32 and i64 indices, assuming that
the index values fall within the i32 range. The assumption is safe because
it is not possible to allocate an input buffer larger than 2 GB.

It treats all data from the input tensor as u32 words, copying one word per
element for 4-byte types and two words per element for i64, u64 and double.

---------

Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
This commit is contained in:
Arthur Islamov 2023-08-04 01:09:37 +04:00 committed by GitHub
parent acb9e56164
commit ea55700e1c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 199 additions and 3 deletions

View file

@ -35,6 +35,7 @@ Do not modify directly.*
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |

View file

@ -7,6 +7,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gelu} from './ops/gelu';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gather', [gather, parseGatherAttributes]],
['Gelu', [gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],

View file

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
import {ShaderHelper} from './common';
/**
 * Attributes of the Gather operator as consumed by the WebGPU implementation in this file.
 * Extends AttributeWithCacheKey; `gather` reads the inherited `cacheKey` when building its cache hint.
 */
export interface GatherAttributes extends AttributeWithCacheKey {
// Axis of the data tensor to gather along; it is normalized against the input rank before use.
axis: number;
}
/**
 * Checks that Gather received exactly two inputs (the data tensor and the indices tensor).
 *
 * @param inputs - the tensors passed to the operator.
 * @throws Error when `inputs` is missing or does not contain exactly two tensors.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const inputCount = inputs?.length;
  if (inputCount === 2) {
    return;
  }
  throw new Error('Gather requires 2 inputs.');
};
/**
 * Builds the ProgramInfo (output description, WGSL shader source and dispatch size) for Gather.
 *
 * @param metadata - program metadata; it is spread into the returned ProgramInfo.
 * @param inputs - inputs[0] is the data tensor, inputs[1] is the indices tensor.
 * @param attributes - Gather attributes; only `axis` is read here.
 * @returns a ProgramInfo with a single output whose dims are the input dims with the `axis` dimension
 *     replaced by the indices dims, and whose data type equals that of inputs[0].
 */
const createGatherProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const indicesShape = inputs[1].dims;
const inputRank = inputShape.length;
// NOTE(review): normalizeAxis presumably maps a negative axis into [0, inputRank) - confirm in ShapeUtil.
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
// Output shape = copy of the input shape with the single `axis` dimension replaced by all indices dimensions.
const outputShape = inputShape.slice(0);
outputShape.splice(axis, 1, ...indicesShape);
const inputDataType = inputs[0].dataType;
// NOTE(review): presumably the element count of one slice past the gather axis - confirm sizeFromDimension.
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
// The shader binds data as array<u32>, so the 8-byte types listed here take two u32 words per element;
// every other type is counted as one word.
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
// Indices are bound as array<i32>; for int64 indices consecutive values are two words apart.
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;
// Number of u32 words copied for each gathered index.
const blockSize = elementSize * block;
// The shader splits global_idx into (batch, i) using N, and totalGathers = M * N sizes the dispatch.
// NOTE(review): M is presumably the product of the dims before `axis` - confirm sizeToDimension.
const M = ShapeUtil.sizeToDimension(inputShape, axis);
const N = ShapeUtil.size(indicesShape);
// Distance, in u32 words, between consecutive batches in the input and in the output respectively.
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
const gatheredBatchElements = N * block * elementSize;
// Added to negative indices in the shader so that they count from the end of the axis.
const axisDimLimit = inputShape[axis];
// Buffer sizes in u32 words; the shader returns early when a computed offset reaches either limit.
const inputSize = ShapeUtil.size(inputShape) * elementSize;
const outputSize = ShapeUtil.size(outputShape) * elementSize;
const totalGathers = M * N;
// int64 indices are read as little-endian i32 (only the first 32-bit word of each index is used),
// assuming the values fit in the i32 range. Per the original author, that assumption is safe because
// an input tensor buffer larger than 2 GB cannot be allocated.
// Input data is copied as u32 words: one word per element, or two words for 8-byte element types.
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const elementSize: u32 = ${elementSize};
const indicesElementSize: u32 = ${indicesElementSize};
@group(0) @binding(0) var<storage, read> input : array<u32>;
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
${shaderHelper.mainStart()}
let batch: u32 = global_idx / N;
let i: u32 = global_idx % N;
let srcOffsetBatch: u32 = batch * ${dataBatchElements};
let dstOffsetBatch: u32 = batch * ${gatheredBatchElements};
var idx = inputIndices[i * indicesElementSize];
if (idx < 0) {
idx = idx + ${axisDimLimit};
}
let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
let dstOffset = dstOffsetBatch + i * ${blockSize};
if (srcOffset >= ${inputSize}) {
return;
}
if (dstOffset >= ${outputSize}) {
return;
}
for (var j: u32 = 0; j < ${blockSize}; j++) {
output[dstOffset + j] = input[srcOffset + j];
}
}`;
return {
...metadata,
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
// One invocation per (batch, index) pair, rounded up to whole workgroups of 64.
dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)})
};
};
/**
 * Converts the raw attribute record of a Gather node into GatherAttributes.
 *
 * @param attributes - raw attribute record; its `axis` entry is taken as a number.
 * @returns the result of createAttributeWithCacheKey for an object holding only `axis`.
 */
export const parseGatherAttributes = (attributes: Record<string, unknown>): GatherAttributes => {
  const {axis} = attributes;
  return createAttributeWithCacheKey({axis: axis as number});
};
/**
 * Entry point of the Gather operator: validates the inputs, assembles the program metadata and
 * hands the generated program to the compute context.
 *
 * @param context - compute context supplying the input tensors and the `compute` method.
 * @param attributes - parsed Gather attributes; `cacheKey` contributes to the cache hint.
 * @throws Error when the context does not provide exactly two inputs.
 */
export const gather = (context: ComputeContext, attributes: GatherAttributes): void => {
  const {inputs} = context;
  validateInputs(inputs);

  // The cache hint combines the attribute cache key with the data types of both inputs.
  const dataTypeSuffix = inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10);
  const metadata = {
    name: 'Gather',
    inputTypes: [GpuDataType.default, GpuDataType.default],
    cacheHint: attributes.cacheKey + dataTypeSuffix,
  };

  const programInfo = createGatherProgramInfo(metadata, context.inputs, attributes);
  context.compute(programInfo);
};

View file

@ -535,9 +535,10 @@
"test_flatten_negative_axis4",
"test_floor_example",
"test_floor",
// "test_gather_0",
// "test_gather_1",
// "test_gather_2d_indices",
"test_gather_0",
"test_gather_1",
"test_gather_2d_indices",
"test_gather_negative_indices",
// "test_gather_elements_0",
// "test_gather_elements_1",
// "test_gather_elements_negative_indices",

View file

@ -266,6 +266,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
@ -477,6 +481,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/js/js_kernel.h"
#include "gather.h"
namespace onnxruntime {
namespace js {
// Element types accepted for the Gather data input (used below for type constraint "T").
// The list contains only the 32-bit and 64-bit numeric types.
using AllSupportedSize =
TypeList<
float,
double,
int64_t,
uint64_t,
int32_t,
uint32_t>;
// Gather in the ONNX domain, opset versions 1 through 10, for the JS execution provider.
// "T" (data) is limited to AllSupportedSize; "Tind" (indices) is limited to int32_t and int64_t.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gather,
kOnnxDomain,
1,
10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
Gather);
// Gather in the ONNX domain, opset versions 11 through 12, for the JS execution provider.
// "T" (data) is limited to AllSupportedSize; "Tind" (indices) is limited to int32_t and int64_t.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Gather,
kOnnxDomain,
11,
12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
Gather);
// Gather in the ONNX domain from opset version 13 onward (non-versioned registration),
// for the JS execution provider.
// "T" (data) is limited to AllSupportedSize; "Tind" (indices) is limited to int32_t and int64_t.
ONNX_OPERATOR_KERNEL_EX(
Gather,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
Gather);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
// JS execution provider kernel for the Gather operator.
// The constructor reads the "axis" attribute and passes it to JSEP_INIT_KERNEL_ATTRIBUTE as
// the object {"axis": Number($1)}.
// NOTE(review): presumably this object is what the JavaScript-side attribute parser receives - confirm.
class Gather : public JsKernel {
public:
Gather(const OpKernelInfo& info) : JsKernel(info) {
// "axis" falls back to 0 when the attribute is not set on the node.
int64_t axis = info.GetAttrOrDefault<int64_t>("axis", 0);
// The axis is narrowed to int32_t and supplied as the argument referenced by $1.
JSEP_INIT_KERNEL_ATTRIBUTE(Gather, ({
"axis" : Number($1),
}),
static_cast<int32_t>(axis));
}
};
} // namespace js
} // namespace onnxruntime