From 03ce0a56935f512d7acfb8a4217d21187a658ad8 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Tue, 25 Jul 2023 14:19:20 -0700 Subject: [PATCH] [Web/JS] Added Slice operator in JSEP. (#16811) ### Description Added Slice operator support to JSEP. ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + js/web/lib/wasm/jsep/init.ts | 8 + js/web/lib/wasm/jsep/tensor.ts | 5 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 201 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 16 +- .../providers/js/js_execution_provider.cc | 10 + .../core/providers/js/operators/slice.cc | 58 +++++ .../core/providers/js/operators/slice.h | 40 ++++ .../core/providers/js/operators/split.h | 2 +- 10 files changed, 334 insertions(+), 9 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/slice.ts create mode 100644 onnxruntime/core/providers/js/operators/slice.cc create mode 100644 onnxruntime/core/providers/js/operators/slice.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index f6b82b4ac7..7eea38db17 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -60,6 +60,7 @@ Do not modify directly.* | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | | Sinh | ai.onnx(9+) | | +| Slice | ai.onnx(1-9,10,11-12,13+) | | | Split | ai.onnx(1,2-10,11-12,13-17,18+) | | | Sqrt | ai.onnx(6-12,13+) | | | Squeeze | ai.onnx(1-10,11-12,13+) | | diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 7c0ab2c786..76a507a24b 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -37,6 +37,14 @@ class TensorViewImpl implements TensorView { new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount); } + getInt32Array(): Int32Array { + if (this.dataType !== DataType.int32) { + throw new Error('Invalid data type'); + } + const elementCount = ShapeUtil.size(this.dims); + return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount); + } + reshape(newDims: readonly number[]): TensorView { if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) { throw new Error('Invalid new shape'); diff --git a/js/web/lib/wasm/jsep/tensor.ts b/js/web/lib/wasm/jsep/tensor.ts index 7f80fb2466..abe61e07fc 100644 --- a/js/web/lib/wasm/jsep/tensor.ts +++ b/js/web/lib/wasm/jsep/tensor.ts @@ -103,6 +103,11 @@ export interface TensorView { */ getBigInt64Array(): BigInt64Array; + /** + * get a Int32Array data view of the tensor data. tensor data must be on CPU. + */ + getInt32Array(): Int32Array; + /** * create a new tensor view with the same data but different dimensions. */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 6a43cca4e1..c4bb5cf92a 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -10,6 +10,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; import * as pool from './ops/pool'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; +import {parseSliceAttributes, slice} from './ops/slice'; import {parseSplitAttributes, split} from './ops/split'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; @@ -69,6 +70,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], + ['Slice', [slice, parseSliceAttributes]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Sub', [binaryOps.sub]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts new file mode 100644 index 0000000000..ce1025b304 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +export interface SliceAttributes extends AttributeWithCacheKey { + readonly starts: number[]; + readonly ends: number[]; + readonly axes: number[]; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: SliceAttributes): void => { + if (!inputs || inputs.length < 1) { + throw new Error('too few inputs'); + } + if (attributes.axes.length !== 0) { + if (attributes.axes.length !== attributes.starts.length || attributes.axes.length !== attributes.ends.length) { + throw new Error('axes, starts and ends must have the same length'); + } + } else if (attributes.starts.length !== attributes.ends.length) { + throw new Error('starts and ends must have the same length'); + } + inputs.slice(1).forEach((_, idx) => { + if (inputs[idx + 1].dataType !== DataType.int32 && inputs[idx + 1].dataType !== DataType.int64) { + throw new Error(`Input ${idx} must be an array of int32 or int64`); + } + }); +}; + +const readInput = (inputs: readonly TensorView[], idx: number): number[] => { + const input: number[] = []; + if (inputs.length > idx) { + if (inputs[idx].dataType === DataType.int64) { + inputs[idx].getBigInt64Array().forEach(v => input.push(Number(v))); + } else if (inputs[1].dataType === DataType.int32) { + inputs[idx].getInt32Array().forEach(v => input.push(Number(v))); + } else { + throw new Error(`Input ${idx} must be an array of int32 or int64`); + } + } + return input; +}; + +const createSliceAttributesFromInputs = + (inputs: readonly TensorView[], attributes: SliceAttributes): SliceAttributes => { + if (inputs.length > 1) { + const starts: number[] = readInput(inputs, 1); + const ends: number[] = readInput(inputs, 2); + let axes: number[] = readInput(inputs, 3); + if (axes.length === 0) { + axes = [...Array(inputs[0].dims.length).keys()]; + } + return createAttributeWithCacheKey({starts, ends, axes}); + } else { + return attributes; + } + }; + +const fixStartEndValues = + (value: number, index: number, inputShape: readonly number[], axes: readonly number[], steps: readonly number[]): + number => { + let newValue = value; + if (value < 0) { + newValue += inputShape[axes[index]]; + } + if (steps[index] < 0) { + return Math.max(0, Math.min(newValue, inputShape[axes[index]] - 1)); + } else { + return Math.max(0, Math.min(newValue, inputShape[axes[index]])); + } + }; + +const calculateInputIndicesImpl = (inputShape: readonly number[], outputShape: readonly number[]): string => { + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + + return `fn calculateInputIndices(outputIndices: ${outputIndicesHelper.iType}) -> ${inputIndicesHelper.iType} { + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}; + var carry = 0u; + for (var i = ${inputShape.length}; i >= 0; i--) { + var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'}; + var inputIndex = outputIndex * steps[i] + starts[i] + carry; + carry = inputIndex / inputShape[i]; + inputIndex = inputIndex % inputShape[i]; + if (signs[i] < 0) { + inputIndex = inputShape[i] - inputIndex - 1u + starts[i]; + } + ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex; + } + return inputIndices; + }`; +}; + +const createSliceProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const inputSize = ShapeUtil.size(inputShape); + const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) : + [...Array(inputShape.length).keys()]; + const dataType = 'f32'; // TODO: support other data type + let steps = readInput(inputs, 4); + steps.forEach((step) => step !== 0 || (() => { + throw new Error('step cannot be 0'); + })); + if (steps.length === 0) { + steps = Array(axes.length).fill(1); + } + const starts = attributes.starts.map((start, i) => fixStartEndValues(start, i, inputShape, axes, steps)); + + const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps)); + + if (axes.length !== inputShape.length) { + for (let i = 0; i < inputShape.length; ++i) { + if (!axes.includes(i)) { + starts.splice(i, 0, 0); + ends.splice(i, 0, inputShape[i]); + steps.splice(i, 0, 1); + } + } + } + const signs = steps.map(step => Math.sign(step)); + // Convert negative steps to positive steps and reverse starts and ends + steps.forEach((step, i, array) => { + if (step < 0) { + const numSteps = (ends[i] - starts[i]) / step; + const newEnd = starts[i]; + const newStart = newEnd + numSteps * steps[i]; + starts[i] = newStart; + ends[i] = newEnd; + array[i] = -step; + } + }); + + const outputShape = inputShape.slice(0); + axes.forEach((axis, _) => { + outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]); + }); + + const output: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}; + + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const outputSize = ShapeUtil.size(outputShape); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var output: array<${dataType}>; + const signs = array(${signs.map(i => `${i}i`).join(',')}); + const starts = array(${starts.map(i => `${i}u`).join(',')}); + const ends = array(${ends.map(i => `${i}u`).join(',')}); + const steps = array(${steps.map(i => `${i}u`).join(',')}); + const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + ${calculateInputIndicesImpl(inputShape, outputShape)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + inputIndices = calculateInputIndices(outputIndices); + output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; + }`; + return { + ...metadata, + getShaderSource, + outputs: [output], + dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)}) + }; + }; + +const createSliceProgramInfoLoader = + (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfoLoader => { + const updatedAttributes = createSliceAttributesFromInputs(inputs, attributes); + const metadata: ProgramMetadata = { + name: 'Slice', + inputTypes: [GpuDataType.default], + cacheHint: updatedAttributes.cacheKey + (inputs.length > 4 ? 'steps_' + inputs[4].dims.toString() : '') + }; + return {...metadata, get: () => createSliceProgramInfo(metadata, inputs, updatedAttributes)}; + }; + +export const slice = (context: ComputeContext, attributes: SliceAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createSliceProgramInfoLoader(context.inputs, attributes), {inputs: [0]}); +}; + +export const parseSliceAttributes = (attributes: Record): SliceAttributes => { + const starts = attributes.starts as number[]; + const ends = attributes.ends as number[]; + const axes = attributes.axes as number[]; + const steps: number[] = []; + return createAttributeWithCacheKey({starts, ends, axes, steps}); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index e5ca91e2f1..1f3b2b979c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1121,14 +1121,14 @@ "test_sinh", // // "test_size_example", // // "test_size", - // "test_slice_default_axes", - // "test_slice_default_steps", - // "test_slice_end_out_of_bounds", - // "test_slice_neg_steps", - // "test_slice_neg", - // "test_slice_negative_axes", - // "test_slice_start_out_of_bounds", - // "test_slice", + "test_slice_default_axes", + "test_slice_default_steps", + "test_slice_end_out_of_bounds", + "test_slice_neg_steps", + "test_slice_neg", + "test_slice_negative_axes", + "test_slice_start_out_of_bounds", + "test_slice", // "test_softmax_axis_0_expanded", // "test_softmax_axis_0", // "test_softmax_axis_1_expanded", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 9a0c2a0c64..0365e0ae0d 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -246,6 +246,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Spl class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -428,6 +433,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/slice.cc b/onnxruntime/core/providers/js/operators/slice.cc new file mode 100644 index 0000000000..9cc96a5308 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/slice.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "slice.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Slice, + kOnnxDomain, + 1, 9, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Slice_1); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Slice, + kOnnxDomain, + 10, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Slice); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Slice, + kOnnxDomain, + 11, 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Slice); + +ONNX_OPERATOR_KERNEL_EX( + Slice, + kOnnxDomain, + 13, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Slice); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h new file mode 100644 index 0000000000..6792997025 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/tensor.h" +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class Slice : public JsKernel, public SliceBase { + public: + Slice(const OpKernelInfo& info, bool dynamic = true) : JsKernel(info), SliceBase(info, dynamic) { + auto attr_axes = AxesAttribute(); + auto attr_starts = StartsAttribute(); + auto attr_ends = EndsAttribute(); + std::vector axes(attr_axes.begin(), attr_axes.end()); + std::vector starts(attr_starts.begin(), attr_starts.end()); + std::vector ends(attr_ends.begin(), attr_ends.end()); + + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [], + "ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [], + "axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}), + gsl::narrow_cast(starts.size()), + reinterpret_cast((starts.size() > 0) ? starts.data() : nullptr) >> 2, + gsl::narrow_cast(ends.size()), + reinterpret_cast((ends.size() > 0) ? ends.data() : nullptr) >> 2, + gsl::narrow_cast(axes.size()), + reinterpret_cast((axes.size() > 0) ? axes.data() : nullptr) >> 2); + } +}; + +class Slice_1 final : public Slice { + public: + Slice_1(const OpKernelInfo& info) : Slice(info, false) {} +}; +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 4031145e46..691af48711 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -51,7 +51,7 @@ class Split : public JsKernel, public SplitBase { static_cast(axis_), static_cast(num_outputs_), gsl::narrow_cast(split_sizes.size()), - reinterpret_cast((!split_sizes.empty() > 0) ? split_sizes.data() : nullptr) >> 2); + reinterpret_cast((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2); } };