mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[Web/JS] Added Split operator support. (#16567)
### Description Added WebGPU/JSEP Split operator support. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
dd13252506
commit
e55a20ece8
7 changed files with 301 additions and 6 deletions
|
|
@ -58,6 +58,7 @@ Do not modify directly.*
|
|||
| Sigmoid | ai.onnx(6-12,13+) | |
|
||||
| Sin | ai.onnx(7+) | |
|
||||
| Sinh | ai.onnx(9+) | |
|
||||
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
|
||||
| Sqrt | ai.onnx(6-12,13+) | |
|
||||
| Squeeze | ai.onnx(1-10,11-12,13+) | |
|
||||
| Sub | ai.onnx(7-12,13,14+) | |
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
|
|||
import {matMul} from './ops/matmul';
|
||||
import * as pool from './ops/pool';
|
||||
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
|
||||
import {parseSplitAttributes, split} from './ops/split';
|
||||
import {parseTransposeAttributes, transpose} from './ops/transpose';
|
||||
import * as unaryOps from './ops/unary-op';
|
||||
import {ComputeContext} from './types';
|
||||
|
|
@ -64,6 +65,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Sigmoid', [unaryOps.sigmoid]],
|
||||
['Sin', [unaryOps.sin]],
|
||||
['Sinh', [unaryOps.sinh]],
|
||||
['Split', [split, parseSplitAttributes]],
|
||||
['Sqrt', [unaryOps.sqrt]],
|
||||
['Sub', [binaryOps.sub]],
|
||||
['Tan', [unaryOps.tan]],
|
||||
|
|
|
|||
138
js/web/lib/wasm/jsep/webgpu/ops/split.ts
Normal file
138
js/web/lib/wasm/jsep/webgpu/ops/split.ts
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
|
||||
|
||||
import {createIndicesHelper, IndicesHelper, ShaderHelper} from './common';
|
||||
|
||||
/**
 * Attributes of the Split operator as consumed by the WebGPU kernel in this file.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Axis to split along. May be negative; the program builder adds the input rank to negative values. */
  readonly axis: number;

  /** Number of output tensors the kernel produces. */
  readonly numOutputs: number;

  /** Extent of each output along `axis`, indexed by output number. */
  readonly splitSizes: number[];
}
|
||||
|
||||
/**
 * Checks that the kernel was handed at least one input tensor.
 *
 * @param inputs - tensors passed to the Split kernel
 * @throws Error with message 'too few inputs' when `inputs` is falsy or has no elements.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const hasAtLeastOneInput = Boolean(inputs) && !(inputs.length < 1);
  if (hasAtLeastOneInput) {
    return;
  }
  throw new Error('too few inputs');
};
|
||||
|
||||
/**
 * Derives Split attributes whose split sizes come from the second input tensor.
 *
 * The sizes are read from `inputs[1]` as 64-bit integers and converted to numbers. When the first
 * dimension of that tensor is not greater than zero, the resulting `splitSizes` is empty.
 * `numOutputs` and `axis` are copied from `attributes` unchanged.
 *
 * @param inputs - kernel inputs; `inputs[1]` must be present
 * @param attributes - attributes parsed from the node
 * @returns a new attribute object with a freshly created cache key
 */
const createSplitAttributesFromInputs =
    (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
      const splitTensor = inputs[1];
      const splitSizes: number[] =
          splitTensor.dims[0] > 0 ? Array.from(splitTensor.getBigInt64Array(), size => Number(size)) : [];
      const {numOutputs, axis} = attributes;
      return createAttributeWithCacheKey({numOutputs, axis, splitSizes});
    };
|
||||
|
||||
/**
 * Generates the WGSL helper `calculateOutputIndex`.
 *
 * The generated function scans the module-level `sizeInConcatAxis` array and returns the first
 * position `i` for which `index < sizeInConcatAxis[i]`, or `numberOfTensors` when none matches.
 *
 * @param numberOfTensors - number of entries in `sizeInConcatAxis` to scan
 * @returns WGSL source of the helper function
 */
const calculateOutputIndexImpl = (numberOfTensors: number): string => {
  const tensorCount = `${numberOfTensors}u`;
  return `
  fn calculateOutputIndex(index: u32) -> u32 {
    for (var i: u32 = 0u; i < ${tensorCount}; i += 1u ) {
      if (index < sizeInConcatAxis[i]) {
        return i;
      }
    }
    return ${tensorCount};
  }`;
};
|
||||
/**
 * Generates the WGSL helper `writeBufferData`, which stores `input[global_idx]` into the output
 * buffer selected by `outputNumber`.
 *
 * With a single output the store is unconditional. Otherwise an `if / else if / else` chain is
 * emitted, with the last output handled by the trailing `else`. All comparisons use explicit
 * `u32` literals so every branch is written the same way.
 *
 * @param indicesHelper - one indices helper per output, in output order; must not be empty
 * @returns WGSL source of the helper function
 */
const writeBufferDataImpl = (indicesHelper: readonly IndicesHelper[]): string => {
  const numberOfTensors = indicesHelper.length;
  const codeLines = indicesHelper.map((helper, i) => {
    const returnSnippet = `output${i}[${helper.i2oExpression('indices', true)}] = input[global_idx];`;
    if (numberOfTensors === 1) {
      return returnSnippet;
    }
    if (i === 0) {
      return `if (outputNumber == ${i}u) { ${returnSnippet} }`;
    }
    if (i === numberOfTensors - 1) {
      return `else { ${returnSnippet} }`;
    }
    return `else if (outputNumber == ${i}u) { ${returnSnippet} }`;
  });
  return `
  fn writeBufferData(outputNumber: u32, indices: ptr<function, ${indicesHelper[0].iType}>, global_idx: u32) {
    ${codeLines.join('\n')}
  }`;
};
|
||||
|
||||
/**
 * Builds the WebGPU program for Split: each invocation takes one input element, determines which
 * output it belongs to along the split axis and writes it there.
 *
 * @param metadata - program metadata; spread into the returned object
 * @param inputs - kernel inputs; only `inputs[0]` (the data tensor) is read
 * @param attributes - split axis, number of outputs and the size of each output along the axis
 * @param dataType - WGSL element type used for the input and output storage buffers
 * @returns program info with the shader generator, one output descriptor per split and the
 *     dispatch size (one invocation per input element, 64 per workgroup)
 */
const createSplitProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SplitAttributes, dataType = 'f32'):
        ProgramInfo => {
          const inputShape = inputs[0].dims;
          const inputSize = ShapeUtil.size(inputShape);
          const rank = inputShape.length;
          // A negative axis counts from the last dimension. Both the output shapes and the shader
          // must use the adjusted value; indexing the shape array with a negative number would
          // only add a stray property and leave the shape untouched.
          const adjustedAxis = attributes.axis < 0 ? rank + attributes.axis : attributes.axis;
          const outputStorageBuffersDeclarations = new Array<string>(attributes.numOutputs);
          const outputIndicesHelpers = new Array<IndicesHelper>(attributes.numOutputs);
          const inputIndicesHelper = createIndicesHelper('input', inputShape);
          // sizeInConcatAxis[i] is the exclusive end offset of output i along the split axis.
          const sizeInConcatAxis = new Array<number>(attributes.numOutputs);
          const outputs: TensorInfo[] = [];
          let previousSum = 0;
          for (let i = 0; i < attributes.numOutputs; i++) {
            previousSum += attributes.splitSizes[i];
            sizeInConcatAxis[i] = previousSum;
            // Binding 0 is the input, so output i lives at binding i + 1.
            outputStorageBuffersDeclarations[i] =
                `@group(0) @binding(${i + 1}) var<storage, read_write> output${i} : array<${dataType}>;`;
            const outputShape = inputShape.slice();
            outputShape[adjustedAxis] = attributes.splitSizes[i];
            outputIndicesHelpers[i] = createIndicesHelper(`output${i}`, outputShape);
            outputs.push({dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default});
          }
          // For rank < 2 the indices value is addressed directly instead of being subscripted.
          const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`;
          const getShaderSource = (shaderHelper: ShaderHelper) => `
  @group(0) @binding(0) var<storage, read> input : array<${dataType}>;
  ${outputStorageBuffersDeclarations.join('\n')}
  ${inputIndicesHelper.o2iImpl}
  ${outputIndicesHelpers.map(o => o.i2oImpl).join('\n')}
  const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
  ${calculateOutputIndexImpl(sizeInConcatAxis.length)}
  ${writeBufferDataImpl(outputIndicesHelpers)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(inputSize)}

    ${inputIndicesHelper.indicesVariableDeclaration('indices')}
    ${inputIndicesHelper.o2iCall('global_idx', 'indices')}
    let outputNumber = calculateOutputIndex(${indicesAxis});
    if (outputNumber != 0) {
      ${indicesAxis} -= sizeInConcatAxis[outputNumber - 1u];
    }
    writeBufferData(outputNumber, &indices, global_idx);
  }`;
          return {
            ...metadata,
            getShaderSource,
            outputs,
            dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)})
          };
        };
|
||||
|
||||
/**
 * Creates the lazy program loader for Split.
 *
 * When a second input is present its values are turned into split sizes. The attributes that
 * produce the cache hint are the same ones used to build the program, so a cached program always
 * matches its key. If the second input yields no sizes, the attributes parsed from the node are
 * used instead.
 *
 * @param inputs - kernel inputs; `inputs[0]` is the data tensor, `inputs[1]` (optional) the sizes
 * @param attributes - attributes parsed from the node
 * @returns loader whose `get()` builds the program info on demand
 */
const createSplitProgramInfoLoader =
    (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfoLoader => {
      const attributesFromInputs =
          inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
      // An empty split tensor carries no information; keep the sizes supplied with the node.
      const updatedAttributes = attributesFromInputs.splitSizes.length > 0 ? attributesFromInputs : attributes;
      const metadata:
          ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
      return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
    };
|
||||
|
||||
/**
 * Entry point of the Split kernel: validates the inputs, then schedules the Split program with
 * only input 0 bound as a program input.
 *
 * @param context - compute context supplying the inputs and the `compute` call
 * @param attributes - attributes parsed from the node
 */
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
  const {inputs} = context;
  validateInputs(inputs);
  const programLoader = createSplitProgramInfoLoader(inputs, attributes);
  context.compute(programLoader, {inputs: [0]});
};
|
||||
|
||||
/**
 * Converts the raw attribute record received from the native kernel into `SplitAttributes`.
 *
 * A negative `numOutputs` means "not specified" and is replaced by the number of split sizes.
 *
 * @param attributes - raw record with `axis`, `numOutputs` and `splitSizes`
 * @returns attributes with a cache key attached
 * @throws Error when the resolved number of outputs differs from the number of split sizes.
 */
export const parseSplitAttributes = (attributes: Record<string, unknown>): SplitAttributes => {
  const axis = attributes.axis as number;
  const splitSizes = attributes.splitSizes as number[];
  const requestedOutputs = attributes.numOutputs as number;
  const numOutputs = requestedOutputs < 0 ? splitSizes.length : requestedOutputs;
  if (numOutputs !== splitSizes.length) {
    throw new Error('numOutputs and splitSizes length must be equal');
  }
  return createAttributeWithCacheKey({axis, numOutputs, splitSizes});
};
|
||||
|
|
@ -1217,12 +1217,12 @@
|
|||
// // "test_softsign",
|
||||
// "test_spacetodepth_example",
|
||||
// "test_spacetodepth",
|
||||
// // "test_split_equal_parts_1d",
|
||||
// // "test_split_equal_parts_2d",
|
||||
// // "test_split_equal_parts_default_axis",
|
||||
// // "test_split_variable_parts_1d",
|
||||
// // "test_split_variable_parts_2d",
|
||||
// // "test_split_variable_parts_default_axis",
|
||||
"test_split_equal_parts_1d",
|
||||
"test_split_equal_parts_2d",
|
||||
"test_split_equal_parts_default_axis",
|
||||
"test_split_variable_parts_1d",
|
||||
"test_split_variable_parts_2d",
|
||||
"test_split_variable_parts_default_axis",
|
||||
// // "test_split_zero_size_splits",
|
||||
"test_sqrt_example",
|
||||
"test_sqrt",
|
||||
|
|
|
|||
|
|
@ -235,6 +235,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Concat);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 1, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -406,6 +412,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Concat)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 1, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
58
onnxruntime/core/providers/js/operators/split.cc
Normal file
58
onnxruntime/core/providers/js/operators/split.cc
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "split.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
// Split, ONNX domain, opset 1 only: float tensors; input 1 is requested in CPU memory.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Split,
    kOnnxDomain,
    1, 1,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPU, 1),
    Split_1);
|
||||
|
||||
// Split, ONNX domain, opsets 2 through 10: float tensors.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Split,
    kOnnxDomain,
    2, 10,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Split_2_10);
|
||||
|
||||
// Split, ONNX domain, opsets 11 through 12: float tensors.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Split,
    kOnnxDomain,
    11, 12,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    Split_11_12);
|
||||
|
||||
// Split, ONNX domain, opsets 13 through 17: float tensors; input 1 is requested in CPU memory.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    Split,
    kOnnxDomain,
    13, 17,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPU, 1),
    Split_13_17);
|
||||
|
||||
// Split, ONNX domain, opset 18 and later: float tensors; input 1 is requested in CPU memory.
ONNX_OPERATOR_KERNEL_EX(
    Split,
    kOnnxDomain,
    18,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPU, 1),
    Split_18);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
84
onnxruntime/core/providers/js/operators/split.h
Normal file
84
onnxruntime/core/providers/js/operators/split.h
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/providers/cpu/tensor/split.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Split : public JsKernel, public SplitBase {
|
||||
public:
|
||||
Split(const OpKernelInfo& info, uint32_t opset) : JsKernel(info), SplitBase(info, opset) {
|
||||
std::vector<int32_t> split_sizes;
|
||||
if (split_sizes_.size() > 0) {
|
||||
ORT_ENFORCE(split_sizes_.size() == info.node().OutputDefs().size(),
|
||||
"Number of outputs (", info.node().OutputDefs().size(), ") does not match split_sizes (",
|
||||
split_sizes_.size(), ")");
|
||||
split_sizes.resize(split_sizes_.size());
|
||||
for (size_t i = 0; i < split_sizes_.size(); ++i) {
|
||||
split_sizes[i] = gsl::narrow_cast<int32_t>(split_sizes_[i]);
|
||||
}
|
||||
if (num_outputs_ < 0) {
|
||||
num_outputs_ = split_sizes.size();
|
||||
}
|
||||
} else if (split_sizes_.size() == 0) {
|
||||
// Compute split_sizes from input shape and num_outputs
|
||||
auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(axis_).dim_value();
|
||||
int64_t split_size_sum = 0;
|
||||
if (num_outputs_ < 0) {
|
||||
num_outputs_ = info.node().OutputDefs().size();
|
||||
} else {
|
||||
ORT_ENFORCE(num_outputs_ == info.node().OutputDefs().size(),
|
||||
"Number of outputs (", info.node().OutputDefs().size(), ") does not match num_outputs (",
|
||||
num_outputs_, ")");
|
||||
}
|
||||
for (auto output : info.node().OutputDefs()) {
|
||||
auto split_size = output->Shape()->dim(axis_).dim_value();
|
||||
split_sizes.push_back(gsl::narrow_cast<int32_t>(split_size));
|
||||
split_size_sum += split_size;
|
||||
}
|
||||
ORT_ENFORCE(split_size_sum == total_split_size,
|
||||
"Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")");
|
||||
}
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1,
|
||||
"numOutputs" : $2,
|
||||
"splitSizes" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}),
|
||||
static_cast<int32_t>(axis_),
|
||||
static_cast<int32_t>(num_outputs_),
|
||||
gsl::narrow_cast<int32_t>(split_sizes.size()),
|
||||
reinterpret_cast<int32_t>((!split_sizes.empty() > 0) ? split_sizes.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
class Split_1 final : public Split {
|
||||
public:
|
||||
Split_1(const OpKernelInfo& info) : Split(info, 1) {}
|
||||
};
|
||||
|
||||
class Split_2_10 final : public Split {
|
||||
public:
|
||||
Split_2_10(const OpKernelInfo& info) : Split(info, 2) {}
|
||||
};
|
||||
|
||||
class Split_11_12 final : public Split {
|
||||
public:
|
||||
Split_11_12(const OpKernelInfo& info) : Split(info, 11) {}
|
||||
};
|
||||
|
||||
class Split_13_17 final : public Split {
|
||||
public:
|
||||
Split_13_17(const OpKernelInfo& info) : Split(info, 13) {}
|
||||
};
|
||||
|
||||
class Split_18 final : public Split {
|
||||
public:
|
||||
Split_18(const OpKernelInfo& info) : Split(info, 18) {}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue