mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/web] JSEP Gather OP (#16855)
### Description Added Gather op that works with both i32 and i64 indices, assuming that values fall into i32 limit. The assumption is safe because it's not possible to allocate more than 2gb buffer for inputs. It treats all data from input tensor as u32, copying 1 or 2 elements for i64, u64 and double. --------- Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
This commit is contained in:
parent
acb9e56164
commit
ea55700e1c
7 changed files with 199 additions and 3 deletions
|
|
@ -35,6 +35,7 @@ Do not modify directly.*
|
|||
| Expand | ai.onnx(8-12,13+) | |
|
||||
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| Gather | ai.onnx(1-10,11-12,13+) | |
|
||||
| Gelu | com.microsoft(1+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
|
|||
import {conv, parseConvAttributes} from './ops/conv';
|
||||
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
|
||||
import {expand} from './ops/expand';
|
||||
import {gather, parseGatherAttributes} from './ops/gather';
|
||||
import {gelu} from './ops/gelu';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
import {matMul} from './ops/matmul';
|
||||
|
|
@ -51,6 +52,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Exp', [unaryOps.exp]],
|
||||
['Expand', [expand]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['Gather', [gather, parseGatherAttributes]],
|
||||
['Gelu', [gelu]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
|
|
|
|||
107
js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Normal file
107
js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
|
||||
export interface GatherAttributes extends AttributeWithCacheKey {
|
||||
axis: number;
|
||||
}
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
throw new Error('Gather requires 2 inputs.');
|
||||
}
|
||||
};
|
||||
|
||||
const createGatherProgramInfo =
|
||||
(metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
const indicesShape = inputs[1].dims;
|
||||
|
||||
const inputRank = inputShape.length;
|
||||
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
|
||||
|
||||
const outputShape = inputShape.slice(0);
|
||||
outputShape.splice(axis, 1, ...indicesShape);
|
||||
|
||||
const inputDataType = inputs[0].dataType;
|
||||
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
|
||||
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
|
||||
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;
|
||||
const blockSize = elementSize * block;
|
||||
const M = ShapeUtil.sizeToDimension(inputShape, axis);
|
||||
const N = ShapeUtil.size(indicesShape);
|
||||
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
|
||||
const gatheredBatchElements = N * block * elementSize;
|
||||
const axisDimLimit = inputShape[axis];
|
||||
|
||||
const inputSize = ShapeUtil.size(inputShape) * elementSize;
|
||||
const outputSize = ShapeUtil.size(outputShape) * elementSize;
|
||||
|
||||
const totalGathers = M * N;
|
||||
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
|
||||
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
|
||||
// Input data will be treated as u32 or two u32 for 8-byte tensors
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const N: u32 = ${N};
|
||||
const elementSize: u32 = ${elementSize};
|
||||
const indicesElementSize: u32 = ${indicesElementSize};
|
||||
|
||||
@group(0) @binding(0) var<storage, read> input : array<u32>;
|
||||
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
|
||||
@group(0) @binding(2) var<storage, read_write> output: array<u32>;
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
let batch: u32 = global_idx / N;
|
||||
let i: u32 = global_idx % N;
|
||||
|
||||
let srcOffsetBatch: u32 = batch * ${dataBatchElements};
|
||||
let dstOffsetBatch: u32 = batch * ${gatheredBatchElements};
|
||||
var idx = inputIndices[i * indicesElementSize];
|
||||
if (idx < 0) {
|
||||
idx = idx + ${axisDimLimit};
|
||||
}
|
||||
|
||||
let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
|
||||
let dstOffset = dstOffsetBatch + i * ${blockSize};
|
||||
if (srcOffset >= ${inputSize}) {
|
||||
return;
|
||||
}
|
||||
if (dstOffset >= ${outputSize}) {
|
||||
return;
|
||||
}
|
||||
for (var j: u32 = 0; j < ${blockSize}; j++) {
|
||||
output[dstOffset + j] = input[srcOffset + j];
|
||||
}
|
||||
}`;
|
||||
return {
|
||||
...metadata,
|
||||
outputs: [
|
||||
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
],
|
||||
getShaderSource,
|
||||
dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)})
|
||||
};
|
||||
};
|
||||
|
||||
export const parseGatherAttributes = (attributes: Record<string, unknown>): GatherAttributes =>
|
||||
createAttributeWithCacheKey({axis: attributes.axis as number});
|
||||
|
||||
export const gather = (context: ComputeContext, attributes: GatherAttributes): void => {
|
||||
const inputs = context.inputs;
|
||||
validateInputs(inputs);
|
||||
|
||||
const metadata = {
|
||||
name: 'Gather',
|
||||
inputTypes: [GpuDataType.default, GpuDataType.default],
|
||||
cacheHint: attributes.cacheKey + inputs[0].dataType.toString(10) + inputs[1].dataType.toString(10),
|
||||
};
|
||||
|
||||
context.compute(createGatherProgramInfo(metadata, context.inputs, attributes));
|
||||
};
|
||||
|
|
@ -535,9 +535,10 @@
|
|||
"test_flatten_negative_axis4",
|
||||
"test_floor_example",
|
||||
"test_floor",
|
||||
// "test_gather_0",
|
||||
// "test_gather_1",
|
||||
// "test_gather_2d_indices",
|
||||
"test_gather_0",
|
||||
"test_gather_1",
|
||||
"test_gather_2d_indices",
|
||||
"test_gather_negative_indices",
|
||||
// "test_gather_elements_0",
|
||||
// "test_gather_elements_1",
|
||||
// "test_gather_elements_negative_indices",
|
||||
|
|
|
|||
|
|
@ -266,6 +266,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
|
||||
|
|
@ -477,6 +481,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Gather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
|
||||
|
|
|
|||
53
onnxruntime/core/providers/js/operators/gather.cc
Normal file
53
onnxruntime/core/providers/js/operators/gather.cc
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
#include "gather.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
using AllSupportedSize =
|
||||
TypeList<
|
||||
float,
|
||||
double,
|
||||
int64_t,
|
||||
uint64_t,
|
||||
int32_t,
|
||||
uint32_t>;
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Gather,
|
||||
kOnnxDomain,
|
||||
1,
|
||||
10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
|
||||
Gather);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Gather,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
|
||||
Gather);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Gather,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", BuildKernelDefConstraintsFromTypeList<AllSupportedSize>())
|
||||
.TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()),
|
||||
Gather);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/core/providers/js/operators/gather.h
Normal file
24
onnxruntime/core/providers/js/operators/gather.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Gather : public JsKernel {
|
||||
public:
|
||||
Gather(const OpKernelInfo& info) : JsKernel(info) {
|
||||
int64_t axis = info.GetAttrOrDefault<int64_t>("axis", 0);
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Gather, ({
|
||||
"axis" : Number($1),
|
||||
}),
|
||||
static_cast<int32_t>(axis));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue