mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
[js/webgpu] Add GatherND (#22847)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
a615bd6688
commit
c19617a24a
8 changed files with 403 additions and 0 deletions
|
|
@ -50,6 +50,7 @@ Do not modify directly.*
|
|||
| Gather | ai.onnx(1-10,11-12,13+) | |
|
||||
| GatherBlockQuantized | com.microsoft(1+) | |
|
||||
| GatherElements | ai.onnx(11-12,13+) | |
|
||||
| GatherND | ai.onnx(11,12,13+) | |
|
||||
| Gelu | ai.onnx(20+); com.microsoft(1+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum';
|
|||
import { expand } from './ops/expand';
|
||||
import { fastGelu } from './ops/fast-gelu';
|
||||
import { gather, parseGatherAttributes } from './ops/gather';
|
||||
import { gatherND, parseGatherNDAttributes } from './ops/gather-nd';
|
||||
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
|
||||
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
|
||||
import { gemm, parseGemmAttributes } from './ops/gemm';
|
||||
|
|
@ -100,6 +101,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Gather', [gather, parseGatherAttributes]],
|
||||
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
|
||||
['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
|
||||
['GatherND', [gatherND, parseGatherNDAttributes]],
|
||||
['Gelu', [unaryOps.gelu]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
|
|
|
|||
179
js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts
Normal file
179
js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
import { ComputeContext, ProgramUniform } from '../types';
|
||||
|
||||
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';
|
||||
|
||||
export interface GatherNDAttributes extends AttributeWithCacheKey {
|
||||
readonly batchDims: number;
|
||||
}
|
||||
|
||||
const computeSliceOffsets = (
|
||||
context: ComputeContext,
|
||||
indicesData: TensorView,
|
||||
sizesFromSliceDimsData: number[],
|
||||
batchDims: number,
|
||||
inputDims: readonly number[],
|
||||
numSlices: number,
|
||||
numSlicesPerBatch: number,
|
||||
inputBatchStride: number,
|
||||
numSliceDims: number,
|
||||
) => {
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.uint32, data: numSlices },
|
||||
{ type: DataType.uint32, data: batchDims },
|
||||
{ type: DataType.uint32, data: inputDims },
|
||||
{ type: DataType.uint32, data: sizesFromSliceDimsData },
|
||||
{ type: DataType.uint32, data: numSlicesPerBatch },
|
||||
{ type: DataType.uint32, data: inputBatchStride },
|
||||
{ type: DataType.uint32, data: numSliceDims },
|
||||
];
|
||||
|
||||
const outputShape = [numSlices];
|
||||
programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length);
|
||||
const output = outputVariable('input_slice_offsets_data', DataType.uint32, 1, 1);
|
||||
const variables = [indices, output];
|
||||
const uniforms: UniformsArrayType = [
|
||||
{ name: 'output_size', type: 'u32' },
|
||||
{ name: 'batch_dims', type: 'u32' },
|
||||
{ name: 'input_dims', type: 'u32', length: inputDims.length },
|
||||
{ name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length },
|
||||
{ name: 'num_slices_per_batch', type: 'u32' },
|
||||
{ name: 'input_batch_stride', type: 'u32' },
|
||||
{ name: 'num_slice_dims', type: 'u32' },
|
||||
];
|
||||
return `
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let batch_idx = global_idx / uniforms.num_slices_per_batch;
|
||||
let base_offset = batch_idx * uniforms.input_batch_stride;
|
||||
|
||||
let slice_indices_base_offset = global_idx * uniforms.num_slice_dims;
|
||||
var relative_slice_offset = 0;
|
||||
for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) {
|
||||
var index = i32(indices_data[dim_idx + slice_indices_base_offset].x);
|
||||
let input_dim_idx = uniforms.batch_dims + dim_idx;
|
||||
if (index < 0) {
|
||||
${
|
||||
inputDims.length === 1
|
||||
? 'index += i32(uniforms.input_dims);'
|
||||
: 'index += i32(uniforms.input_dims[input_dim_idx]);'
|
||||
}
|
||||
}
|
||||
${
|
||||
sizesFromSliceDimsData.length === 1
|
||||
? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);'
|
||||
: 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);'
|
||||
}
|
||||
}
|
||||
|
||||
input_slice_offsets_data[global_idx] = base_offset + u32(relative_slice_offset);
|
||||
}`;
|
||||
};
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'computeSliceOffsets',
|
||||
shaderCache: { hint: `${inputDims.length}_${sizesFromSliceDimsData.length}`, inputDependencies: ['rank'] },
|
||||
getRunData: () => ({
|
||||
outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }],
|
||||
dispatchGroup: { x: Math.ceil(numSlices / 64) },
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{ inputs: [indicesData], outputs: [-1] },
|
||||
)[0];
|
||||
};
|
||||
|
||||
export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes) => {
|
||||
const inputs = context.inputs;
|
||||
const inputShape = inputs[0].dims;
|
||||
const inputType = inputs[0].dataType;
|
||||
const indicesShape = inputs[1].dims;
|
||||
const numSliceDims = indicesShape[indicesShape.length - 1];
|
||||
const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1);
|
||||
const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims);
|
||||
const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims);
|
||||
const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims);
|
||||
const numSlicesPerBatch = numSlices / numBatches;
|
||||
const sizesFromSliceDims = new Array(numSliceDims);
|
||||
let runningProduct = sliceSize;
|
||||
for (let i = 0; i < numSliceDims; ++i) {
|
||||
sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct;
|
||||
runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i];
|
||||
}
|
||||
|
||||
const inputSliceOffsets = computeSliceOffsets(
|
||||
context,
|
||||
inputs[1],
|
||||
sizesFromSliceDims,
|
||||
attributes.batchDims,
|
||||
inputShape,
|
||||
numSlices,
|
||||
numSlicesPerBatch,
|
||||
inputBatchStride,
|
||||
numSliceDims,
|
||||
);
|
||||
|
||||
const lastIndicesDimension = attributes.batchDims + numSliceDims;
|
||||
if (lastIndicesDimension > inputShape.length) {
|
||||
throw new Error('last dimension of indices must not be larger than rank of input tensor');
|
||||
}
|
||||
|
||||
const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension));
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
{ type: DataType.uint32, data: sliceSize },
|
||||
...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape),
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
|
||||
const indices = inputVariable('slice_offsets', DataType.uint32, inputSliceOffsets.dims.length);
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
|
||||
return `
|
||||
${shaderHelper
|
||||
.registerUniform('output_size', 'u32')
|
||||
.registerUniform('slice_size', 'u32')
|
||||
.declareVariables(input, indices, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let slice_offset = slice_offsets[global_idx / uniforms.slice_size];
|
||||
output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size];
|
||||
}`;
|
||||
};
|
||||
context.compute(
|
||||
{
|
||||
name: 'GatherND',
|
||||
shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] },
|
||||
getRunData: () => ({
|
||||
outputs: [{ dims: outputShape, dataType: inputType }],
|
||||
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{ inputs: [inputs[0], inputSliceOffsets] },
|
||||
);
|
||||
};
|
||||
|
||||
export const parseGatherNDAttributes = (attributes: Record<string, unknown>): GatherNDAttributes => {
|
||||
const batchDims = attributes.batch_dims as number;
|
||||
return {
|
||||
batchDims,
|
||||
cacheKey: '',
|
||||
};
|
||||
};
|
||||
147
js/web/test/data/ops/gather-nd.jsonc
Normal file
147
js/web/test/data/ops/gather-nd.jsonc
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
[
|
||||
{
|
||||
"name": "GatherND int32",
|
||||
"operator": "GatherND",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "data[4] indices[]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002],
|
||||
"dims": [9],
|
||||
"type": "int32"
|
||||
},
|
||||
{
|
||||
"data": [0, 4, 8],
|
||||
"dims": [3, 1],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [100, 778, 1002],
|
||||
"dims": [3],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherND float32",
|
||||
"operator": "GatherND",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "data[4] indices[]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
|
||||
"dims": [9],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0, 4, 8],
|
||||
"dims": [3, 1],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [100.0999984741211, 778.5, 1002.9000244140625],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherND int32 [2 2 2], batch_dims",
|
||||
"operator": "GatherND",
|
||||
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "data[4] indices[]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0, 1, 2, 3, 4, 5, 6, 7],
|
||||
"dims": [2, 2, 2],
|
||||
"type": "int32"
|
||||
},
|
||||
{
|
||||
"data": [1, 0],
|
||||
"dims": [2, 1],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [2, 3, 4, 5],
|
||||
"dims": [2, 2],
|
||||
"type": "int32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherND float16",
|
||||
"operator": "GatherND",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "data[4] indices[]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
|
||||
"dims": [9],
|
||||
"type": "float16"
|
||||
},
|
||||
{
|
||||
"data": [0, 4, 8],
|
||||
"dims": [3, 1],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [100.0999984741211, 778.5, 1002.9000244140625],
|
||||
"dims": [3],
|
||||
"type": "float16"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GatherND uint32 [2 2 2], batch_dims",
|
||||
"operator": "GatherND",
|
||||
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "data[4] indices[]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0, 1, 2, 3, 4, 5, 6, 7],
|
||||
"dims": [2, 2, 2],
|
||||
"type": "uint32"
|
||||
},
|
||||
{
|
||||
"data": [1, 0],
|
||||
"dims": [2, 1],
|
||||
"type": "int64"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [2, 3, 4, 5],
|
||||
"dims": [2, 2],
|
||||
"type": "uint32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1365,6 +1365,7 @@
|
|||
"gather.jsonc",
|
||||
"gather-block-quantized.jsonc",
|
||||
"gather-elements.jsonc",
|
||||
"gather-nd.jsonc",
|
||||
"gemm.jsonc",
|
||||
"global-average-pool.jsonc",
|
||||
"greater.jsonc",
|
||||
|
|
|
|||
|
|
@ -341,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, GatherND);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
|
||||
|
|
@ -667,6 +671,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
|
||||
|
|
|
|||
41
onnxruntime/core/providers/js/operators/gather_nd.cc
Normal file
41
onnxruntime/core/providers/js/operators/gather_nd.cc
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
#include "gather_nd.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
GatherND,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
GatherND);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
GatherND,
|
||||
kOnnxDomain,
|
||||
12,
|
||||
12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
GatherND);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
GatherND,
|
||||
kOnnxDomain,
|
||||
11,
|
||||
11,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
GatherND);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/core/providers/js/operators/gather_nd.h
Normal file
24
onnxruntime/core/providers/js/operators/gather_nd.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class GatherND : public JsKernel {
|
||||
public:
|
||||
GatherND(const OpKernelInfo& info) : JsKernel(info) {
|
||||
int64_t batchDims = info.GetAttrOrDefault<int64_t>("batch_dims", 0);
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(GatherND, ({
|
||||
"batch_dims" : Number($1),
|
||||
}),
|
||||
static_cast<int32_t>(batchDims));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue