diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 57cdfb0021..85848753f7 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -58,6 +58,7 @@ Do not modify directly.* | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | +| Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | | Squeeze | ai.onnx(1-10,11-12,13+) | | | Sub | ai.onnx(7-12,13,14+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index b9fc9d26fb..34fc8c0ae9 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -8,6 +8,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; import * as pool from './ops/pool'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; +import {parseSplitAttributes, split} from './ops/split'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; import {ComputeContext} from './types'; @@ -64,6 +65,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], + ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Sub', [binaryOps.sub]], ['Tan', [unaryOps.tan]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts new file mode 100644 index 0000000000..82668ce45c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types'; + +import {createIndicesHelper, IndicesHelper, ShaderHelper} from './common'; + +export interface SplitAttributes extends AttributeWithCacheKey { + readonly axis: number; + readonly numOutputs: number; + readonly splitSizes: number[]; +} + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length < 1) { + throw new Error('too few inputs'); + } +}; + +const createSplitAttributesFromInputs = + (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { + const splitSizes: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + } + return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + }; + +const calculateOutputIndexImpl = (numberOfTensors: number): string => ` +fn calculateOutputIndex(index: u32) -> u32 { + for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + if (index < sizeInConcatAxis[i]) { + return i; + } + } + return ${numberOfTensors}u; +}`; +const writeBufferDataImpl = (indicesHelper: readonly IndicesHelper[]) => { + const numberOfTensors = indicesHelper.length; + const codeLines: string[] = []; + for (let i = 0; i < numberOfTensors; ++i) { + const returnSnippet = `output${i}[${indicesHelper[i].i2oExpression('indices', true)}] = input[global_idx];`; + if (numberOfTensors === 1) { + codeLines.push(returnSnippet); + } else if (i === 0) { + codeLines.push(`if (outputNumber == ${i}u) { ${returnSnippet} }`); + } else if (i === numberOfTensors - 1) { + codeLines.push(`else { ${returnSnippet} }`); + } else { + codeLines.push(`else if (outputNumber == ${i}) { ${returnSnippet} }`); + } + } + return ` + fn writeBufferData(outputNumber: u32, indices: ptr, global_idx: u32) { + ${codeLines.join('\n')} + }`; +}; + +const createSplitProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SplitAttributes, dataType = 'f32'): + ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const rank = inputShape.length; + const axis = attributes.axis; + const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis; + const outputStorageBuffersDeclarations = new Array(attributes.numOutputs); + const outputIndicesHelpers = new Array(attributes.numOutputs); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const sizeInConcatAxis = new Array(attributes.numOutputs); + const outputs: TensorInfo[] = []; + const outputShapes: number[][] = []; + let previousSum = 0; + for (let i = 0; i < attributes.numOutputs; i++) { + previousSum += attributes.splitSizes[i]; + sizeInConcatAxis[i] = previousSum; + outputStorageBuffersDeclarations[i] = + `@group(0) @binding(${i + 1}) var output${i} : array<${dataType}>;`; + const outputShape = inputShape.slice(); + outputShape[attributes.axis] = attributes.splitSizes[i]; + outputShapes.push(outputShape); + outputIndicesHelpers[i] = createIndicesHelper(`output${i}`, outputShapes[i]); + outputs.push({dims: outputShapes[i], dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}); + } + const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + @group(0) @binding(0) var input : array<${dataType}>; + ${outputStorageBuffersDeclarations.join('\n')} + ${inputIndicesHelper.o2iImpl} + ${outputIndicesHelpers.map(o => o.i2oImpl).join('\n')} + const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); + ${calculateOutputIndexImpl(sizeInConcatAxis.length)} + ${writeBufferDataImpl(outputIndicesHelpers)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(inputSize)} + + ${inputIndicesHelper.indicesVariableDeclaration('indices')} + ${inputIndicesHelper.o2iCall('global_idx', 'indices')} + let outputNumber = calculateOutputIndex(${indicesAxis}); + if (outputNumber != 0) { + ${indicesAxis} -= sizeInConcatAxis[outputNumber - 1u]; + } + writeBufferData(outputNumber, &indices, global_idx); + }`; + return { + ...metadata, + getShaderSource, + outputs, + dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)}) + }; + }; + +const createSplitProgramInfoLoader = + (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfoLoader => { + const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); + const metadata: + ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; + return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)}; + }; + +export const split = (context: ComputeContext, attributes: SplitAttributes): void => { + validateInputs(context.inputs); + context.compute(createSplitProgramInfoLoader(context.inputs, attributes), {inputs: [0]}); +}; + +export const parseSplitAttributes = (attributes: Record): SplitAttributes => { + const axis = attributes.axis as number; + const splitSizes: number[] = attributes.splitSizes as number[]; + const numOutputs = attributes.numOutputs as number < 0 ? splitSizes.length : attributes.numOutputs as number; + if (numOutputs !== splitSizes.length) { + throw new Error('numOutputs and splitSizes lengh must be equal'); + } + return createAttributeWithCacheKey({axis, numOutputs, splitSizes}); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 17f0b1870c..b818d4b5cb 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1217,12 +1217,12 @@ // // "test_softsign", // "test_spacetodepth_example", // "test_spacetodepth", - // // "test_split_equal_parts_1d", - // // "test_split_equal_parts_2d", - // // "test_split_equal_parts_default_axis", - // // "test_split_variable_parts_1d", - // // "test_split_variable_parts_2d", - // // "test_split_variable_parts_default_axis", + "test_split_equal_parts_1d", + "test_split_equal_parts_2d", + "test_split_equal_parts_default_axis", + "test_split_variable_parts_1d", + "test_split_variable_parts_2d", + "test_split_variable_parts_default_axis", // // "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index b09ccc3ae6..0b8431edf9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -235,6 +235,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -406,6 +412,12 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/split.cc b/onnxruntime/core/providers/js/operators/split.cc new file mode 100644 index 0000000000..e5bb89ade1 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/split.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "split.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Split, + kOnnxDomain, + 1, 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Split_1); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Split, + kOnnxDomain, + 2, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Split_2_10); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Split, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Split_11_12); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Split, + kOnnxDomain, + 13, 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Split_13_17); + +ONNX_OPERATOR_KERNEL_EX( + Split, + kOnnxDomain, + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Split_18); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h new file mode 100644 index 0000000000..5bd08221ea --- /dev/null +++ b/onnxruntime/core/providers/js/operators/split.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/tensor.h" +#include "core/providers/cpu/tensor/split.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class Split : public JsKernel, public SplitBase { + public: + Split(const OpKernelInfo& info, uint32_t opset) : JsKernel(info), SplitBase(info, opset) { + std::vector split_sizes; + if (split_sizes_.size() > 0) { + ORT_ENFORCE(split_sizes_.size() == info.node().OutputDefs().size(), + "Number of outputs (", info.node().OutputDefs().size(), ") does not match split_sizes (", + split_sizes_.size(), ")"); + split_sizes.resize(split_sizes_.size()); + for (size_t i = 0; i < split_sizes_.size(); ++i) { + split_sizes[i] = gsl::narrow_cast(split_sizes_[i]); + } + if (num_outputs_ < 0) { + num_outputs_ = split_sizes.size(); + } + } else if (split_sizes_.size() == 0) { + // Compute split_sizes from input shape and num_outputs + auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(axis_).dim_value(); + int64_t split_size_sum = 0; + if (num_outputs_ < 0) { + num_outputs_ = info.node().OutputDefs().size(); + } else { + ORT_ENFORCE(num_outputs_ == info.node().OutputDefs().size(), + "Number of outputs (", info.node().OutputDefs().size(), ") does not match num_outputs (", + num_outputs_, ")"); + } + for (auto output : info.node().OutputDefs()) { + auto split_size = output->Shape()->dim(axis_).dim_value(); + split_sizes.push_back(gsl::narrow_cast(split_size)); + split_size_sum += split_size; + } + ORT_ENFORCE(split_size_sum == total_split_size, + "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, + "numOutputs" : $2, + "splitSizes" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : []}), + static_cast(axis_), + static_cast(num_outputs_), + gsl::narrow_cast(split_sizes.size()), + reinterpret_cast((!split_sizes.empty() > 0) ? split_sizes.data() : nullptr) >> 2); + } +}; + +class Split_1 final : public Split { + public: + Split_1(const OpKernelInfo& info) : Split(info, 1) {} +}; + +class Split_2_10 final : public Split { + public: + Split_2_10(const OpKernelInfo& info) : Split(info, 2) {} +}; + +class Split_11_12 final : public Split { + public: + Split_11_12(const OpKernelInfo& info) : Split(info, 11) {} +}; + +class Split_13_17 final : public Split { + public: + Split_13_17(const OpKernelInfo& info) : Split(info, 13) {} +}; + +class Split_18 final : public Split { + public: + Split_18(const OpKernelInfo& info) : Split(info, 18) {} +}; + +} // namespace js +} // namespace onnxruntime