[Web/JS] Added Slice operator in JSEP. (#16811)

### Description
Added Slice operator support to JSEP.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-07-25 14:19:20 -07:00 committed by GitHub
parent a1bb670536
commit 03ce0a5693
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 334 additions and 9 deletions

View file

@ -60,6 +60,7 @@ Do not modify directly.*
| Sigmoid | ai.onnx(6-12,13+) | |
| Sin | ai.onnx(7+) | |
| Sinh | ai.onnx(9+) | |
| Slice | ai.onnx(1-9,10,11-12,13+) | |
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
| Sqrt | ai.onnx(6-12,13+) | |
| Squeeze | ai.onnx(1-10,11-12,13+) | |

View file

@ -37,6 +37,14 @@ class TensorViewImpl implements TensorView {
new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
}
/**
 * Returns an Int32Array view over this tensor's data inside the module heap buffer.
 *
 * @returns an empty Int32Array when the shape has zero elements, otherwise a view of
 *     `ShapeUtil.size(this.dims)` elements starting at byte offset `this.data`.
 * @throws Error when the tensor's data type is not int32.
 */
getInt32Array(): Int32Array {
  if (this.dataType !== DataType.int32) {
    throw new Error('Invalid data type');
  }

  const length = ShapeUtil.size(this.dims);
  if (length === 0) {
    return new Int32Array();
  }
  return new Int32Array(this.module.HEAP8.buffer, this.data, length);
}
reshape(newDims: readonly number[]): TensorView {
if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
throw new Error('Invalid new shape');

View file

@ -103,6 +103,11 @@ export interface TensorView {
*/
getBigInt64Array(): BigInt64Array;
/**
* get an Int32Array data view of the tensor data. tensor data must be on CPU.
*/
getInt32Array(): Int32Array;
/**
* create a new tensor view with the same data but different dimensions.
*/

View file

@ -10,6 +10,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
import * as pool from './ops/pool';
import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSplitAttributes, split} from './ops/split';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
@ -69,6 +70,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Sigmoid', [unaryOps.sigmoid]],
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
['Slice', [slice, parseSliceAttributes]],
['Split', [split, parseSplitAttributes]],
['Sqrt', [unaryOps.sqrt]],
['Sub', [binaryOps.sub]],

View file

@ -0,0 +1,201 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
import {createIndicesHelper, ShaderHelper} from './common';
/**
 * Attributes consumed by the Slice program in this file.
 *
 * `validateInputs` requires `starts` and `ends` to have the same length, and `axes`, when it is
 * non-empty, to have that same length as well.
 */
export interface SliceAttributes extends AttributeWithCacheKey {
// Start position per sliced axis; negative values are offset by the axis size (see fixStartEndValues).
readonly starts: number[];
// End position per sliced axis; negative values are offset by the axis size (see fixStartEndValues).
readonly ends: number[];
// Axes that `starts`/`ends` refer to; when empty, every axis of the input is used.
readonly axes: number[];
}
/**
 * Validates the inputs and attributes of a Slice call.
 *
 * @param inputs - the operator inputs; input 0 is the data tensor, any further inputs must be
 *     int32 or int64 tensors.
 * @param attributes - the parsed Slice attributes.
 * @throws Error when there is no input, when the lengths of `axes`/`starts`/`ends` disagree, or
 *     when any input after the first is neither int32 nor int64.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: SliceAttributes): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  if (attributes.axes.length !== 0) {
    if (attributes.axes.length !== attributes.starts.length || attributes.axes.length !== attributes.ends.length) {
      throw new Error('axes, starts and ends must have the same length');
    }
  } else if (attributes.starts.length !== attributes.ends.length) {
    throw new Error('starts and ends must have the same length');
  }
  // Iterate with the real input index so the error message names the offending input
  // (the same numbering that readInput uses in its own error message).
  for (let i = 1; i < inputs.length; ++i) {
    const dataType = inputs[i].dataType;
    if (dataType !== DataType.int32 && dataType !== DataType.int64) {
      throw new Error(`Input ${i} must be an array of int32 or int64`);
    }
  }
};
/**
 * Reads the integer tensor at position `idx` into a plain number array.
 *
 * @param inputs - the operator inputs.
 * @param idx - index of the input to read.
 * @returns the tensor values converted to numbers, or an empty array when `inputs` has no
 *     element at `idx`.
 * @throws Error when the input at `idx` is neither int64 nor int32.
 */
const readInput = (inputs: readonly TensorView[], idx: number): number[] => {
  if (inputs.length <= idx) {
    return [];
  }
  const tensor = inputs[idx];
  // The data type of the tensor being read decides which typed view to use; checking any other
  // input's type would pick the wrong accessor when the optional inputs differ in type.
  switch (tensor.dataType) {
    case DataType.int64:
      return Array.from(tensor.getBigInt64Array(), v => Number(v));
    case DataType.int32:
      return Array.from(tensor.getInt32Array());
    default:
      throw new Error(`Input ${idx} must be an array of int32 or int64`);
  }
};
/**
 * Builds the effective Slice attributes.
 *
 * When only the data tensor is supplied the given attributes are returned unchanged. Otherwise
 * starts, ends and axes are read from inputs 1, 2 and 3; an empty axes list is replaced by all
 * axes of the data tensor.
 */
const createSliceAttributesFromInputs =
    (inputs: readonly TensorView[], attributes: SliceAttributes): SliceAttributes => {
      if (inputs.length <= 1) {
        return attributes;
      }
      const starts = readInput(inputs, 1);
      const ends = readInput(inputs, 2);
      const axesFromInput = readInput(inputs, 3);
      const axes =
          axesFromInput.length > 0 ? axesFromInput : Array.from({length: inputs[0].dims.length}, (_, i) => i);
      return createAttributeWithCacheKey({starts, ends, axes});
    };
/**
 * Normalizes one start or end value against the size of the axis it applies to.
 *
 * A negative value is first offset by the axis size. The result is then clamped to
 * `[0, size - 1]` when the step at `index` is negative and to `[0, size]` otherwise.
 *
 * @param value - the raw start or end value.
 * @param index - position of the value within `axes` and `steps`.
 * @param inputShape - shape of the tensor being sliced.
 * @param axes - axis for each start/end entry.
 * @param steps - step for each start/end entry.
 */
const fixStartEndValues =
    (value: number, index: number, inputShape: readonly number[], axes: readonly number[], steps: readonly number[]):
        number => {
          const axisSize = inputShape[axes[index]];
          const resolved = value < 0 ? value + axisSize : value;
          const upperBound = steps[index] < 0 ? axisSize - 1 : axisSize;
          return Math.max(0, Math.min(resolved, upperBound));
        };
/**
 * Generates the WGSL helper `calculateInputIndices`, which maps output indices to input indices.
 *
 * The generated function walks the dimensions from the last one to the first and, for each one,
 * combines the output index with the shader constants `steps`, `starts`, `signs` and `inputShape`
 * (declared by the caller's shader source).
 *
 * @param inputShape - shape of the tensor being sliced.
 * @param outputShape - shape of the slice result.
 * @returns the WGSL source of the helper function.
 */
const calculateInputIndicesImpl = (inputShape: readonly number[], outputShape: readonly number[]): string => {
  const outputIndicesHelper = createIndicesHelper('output', outputShape);
  const inputIndicesHelper = createIndicesHelper('input', inputShape);
  // The loop must start at the last valid dimension (rank - 1). Starting at `rank` would index
  // one element past the end of the per-dimension constant arrays on the first iteration.
  const lastDimension = inputShape.length - 1;

  return `fn calculateInputIndices(outputIndices: ${outputIndicesHelper.iType}) -> ${inputIndicesHelper.iType} {
          ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')};
          var carry = 0u;
          for (var i = ${lastDimension}; i >= 0; i--) {
            var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
            var inputIndex = outputIndex * steps[i] + starts[i] + carry;
            carry = inputIndex / inputShape[i];
            inputIndex = inputIndex % inputShape[i];
            if (signs[i] < 0) {
              inputIndex = inputShape[i] - inputIndex - 1u + starts[i];
            }
            ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
          }
          return inputIndices;
      }`;
};
/**
 * Builds the program info (shader source, output description and dispatch size) for Slice.
 *
 * Steps are read from input 4 and default to 1 for every axis. Starts and ends are normalized
 * with `fixStartEndValues`, dimensions that are not listed in `axes` are filled in as full
 * ranges with step 1, and negative steps are converted to positive steps with a per-dimension
 * sign that the shader uses to reverse the direction.
 *
 * NOTE(review): filling in the missing dimensions with `splice` assumes `axes` is sorted in
 * ascending order - confirm that callers never pass unsorted axes.
 *
 * @param metadata - program metadata that is spread into the result.
 * @param inputs - operator inputs; input 0 is the data tensor.
 * @param attributes - effective Slice attributes (starts, ends, axes).
 * @throws Error when any step is 0.
 */
const createSliceProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => {
      const inputShape = inputs[0].dims;
      const inputSize = ShapeUtil.size(inputShape);
      const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, inputShape.length) :
                                                  [...Array(inputShape.length).keys()];
      const dataType = 'f32';  // TODO: support other data type
      let steps = readInput(inputs, 4);
      // A zero step can never advance; reject it up front instead of dividing by it below.
      if (steps.some(step => step === 0)) {
        throw new Error('step cannot be 0');
      }
      if (steps.length === 0) {
        steps = Array(axes.length).fill(1);
      }
      const starts = attributes.starts.map((start, i) => fixStartEndValues(start, i, inputShape, axes, steps));
      const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps));

      // Expand starts/ends/steps so that they have one entry per input dimension.
      if (axes.length !== inputShape.length) {
        for (let i = 0; i < inputShape.length; ++i) {
          if (!axes.includes(i)) {
            starts.splice(i, 0, 0);
            ends.splice(i, 0, inputShape[i]);
            steps.splice(i, 0, 1);
          }
        }
      }
      const signs = steps.map(step => Math.sign(step));
      // Convert negative steps to positive steps and reverse starts and ends. The swap is done
      // directly (rather than via a divide/multiply round trip) so the values stay exact integers.
      steps.forEach((step, i, array) => {
        if (step < 0) {
          const previousStart = starts[i];
          starts[i] = ends[i];
          ends[i] = previousStart;
          array[i] = -step;
        }
      });

      const outputShape = inputShape.slice(0);
      axes.forEach(axis => {
        // An empty range (end before start) yields a zero-sized dimension, never a negative one.
        outputShape[axis] = Math.max(0, Math.ceil((ends[axis] - starts[axis]) / steps[axis]));
      });

      const output: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default};

      const outputIndicesHelper = createIndicesHelper('output', outputShape);
      const inputIndicesHelper = createIndicesHelper('input', inputShape);
      const outputSize = ShapeUtil.size(outputShape);

      const getShaderSource = (shaderHelper: ShaderHelper) => `
      @group(0) @binding(0) var<storage, read> input: array<${dataType}>;
      @group(0) @binding(1) var<storage, read_write> output: array<${dataType}>;
      const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});
      const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});
      const ends = array<u32, ${ends.length}>(${ends.map(i => `${i}u`).join(',')});
      const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
      const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});

      ${outputIndicesHelper.o2iImpl}
      ${inputIndicesHelper.i2oImpl}
      ${calculateInputIndicesImpl(inputShape, outputShape)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
        ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
        ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
        ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
        inputIndices = calculateInputIndices(outputIndices);
        output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}];
      }`;
      return {
        ...metadata,
        getShaderSource,
        outputs: [output],
        // Dispatch is sized from the input element count; the shader guard above is passed
        // outputSize, presumably so that surplus invocations return early.
        dispatchGroup: () => ({x: Math.ceil(inputSize / 64 /* workgroup size */)})
      };
    };
/**
 * Creates the program info loader for Slice.
 *
 * The effective attributes are resolved from the inputs first. The cache hint combines the
 * attribute cache key with the step values, because the generated shader embeds the steps as
 * constants: two calls that differ only in their step values need different shaders, so the
 * hint must differ as well (the shape of the steps tensor alone does not distinguish them).
 *
 * @param inputs - operator inputs; input 4, when present, holds the steps.
 * @param attributes - the parsed Slice attributes.
 */
const createSliceProgramInfoLoader =
    (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfoLoader => {
      const updatedAttributes = createSliceAttributesFromInputs(inputs, attributes);
      const stepsHint = inputs.length > 4 ? 'steps_' + readInput(inputs, 4).toString() : '';
      const metadata: ProgramMetadata = {
        name: 'Slice',
        inputTypes: [GpuDataType.default],
        cacheHint: updatedAttributes.cacheKey + stepsHint
      };
      return {...metadata, get: () => createSliceProgramInfo(metadata, inputs, updatedAttributes)};
    };
/**
 * Entry point of the Slice operator: validates the inputs, then runs the Slice program with
 * input 0 as its only program input.
 */
export const slice = (context: ComputeContext, attributes: SliceAttributes): void => {
  validateInputs(context.inputs, attributes);
  const programInfoLoader = createSliceProgramInfoLoader(context.inputs, attributes);
  context.compute(programInfoLoader, {inputs: [0]});
};
/**
 * Converts the raw attribute record into Slice attributes with a cache key.
 *
 * `starts`, `ends` and `axes` are taken from the record as number arrays; `steps` is always
 * passed as an empty array.
 */
export const parseSliceAttributes = (attributes: Record<string, unknown>): SliceAttributes =>
    createAttributeWithCacheKey({
      starts: attributes.starts as number[],
      ends: attributes.ends as number[],
      axes: attributes.axes as number[],
      steps: [] as number[],
    });

View file

@ -1121,14 +1121,14 @@
"test_sinh",
// // "test_size_example",
// // "test_size",
// "test_slice_default_axes",
// "test_slice_default_steps",
// "test_slice_end_out_of_bounds",
// "test_slice_neg_steps",
// "test_slice_neg",
// "test_slice_negative_axes",
// "test_slice_start_out_of_bounds",
// "test_slice",
"test_slice_default_axes",
"test_slice_default_steps",
"test_slice_end_out_of_bounds",
"test_slice_neg_steps",
"test_slice_neg",
"test_slice_negative_axes",
"test_slice_start_out_of_bounds",
"test_slice",
// "test_softmax_axis_0_expanded",
// "test_softmax_axis_0",
// "test_softmax_axis_1_expanded",

View file

@ -246,6 +246,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Spl
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
// Slice kernel class declarations for the JS execution provider, one per opset range
// (1-9, 10, 11-12 and 13+); each is referenced again in the kernel table of RegisterKernels().
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
@ -428,6 +433,11 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Slice)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "slice.h"
namespace onnxruntime {
namespace js {
// Slice, opset 1-9: registered with the Slice_1 kernel class. No input memory types are set for
// this range. Only float tensors are accepted for "T".
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Slice,
kOnnxDomain,
1, 9,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Slice_1);
// Slice, opset 10: registered with the Slice kernel class. Inputs 1-4 are marked with
// OrtMemTypeCPU; the JavaScript implementation reads inputs 1-4 (starts, ends, axes, steps)
// through CPU-side typed array views. Only float tensors are accepted for "T".
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Slice,
kOnnxDomain,
10, 10,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Slice);
// Slice, opset 11-12: same kernel class, input memory types and type constraint as opset 10.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Slice,
kOnnxDomain,
11, 12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Slice);
// Slice, opset 13 and later: same kernel class, input memory types and type constraint as
// opset 10.
ONNX_OPERATOR_KERNEL_EX(
Slice,
kOnnxDomain,
13,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Slice);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

#include "core/framework/tensor.h"
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
// JS execution provider kernel for Slice.
//
// The constructor copies the axes/starts/ends attributes exposed by SliceBase into 32-bit arrays
// and hands them to the JavaScript side through JSEP_INIT_KERNEL_ATTRIBUTE, which reads them from
// HEAP32 as (element count, pointer >> 2) pairs.
class Slice : public JsKernel, public SliceBase {
 public:
  // `dynamic` is forwarded unchanged to SliceBase.
  Slice(const OpKernelInfo& info, bool dynamic = true) : JsKernel(info), SliceBase(info, dynamic) {
    // Saturate each value to the int32 range instead of truncating it. The attribute values are
    // presumably 64-bit (ONNX ints); a plain narrowing conversion would wrap a large sentinel
    // such as INT64_MAX (commonly used for "slice to the end") to -1 and change the slice range.
    const auto to_int32 = [](const auto& values) {
      std::vector<int32_t> result;
      for (const auto value : values) {
        result.push_back(static_cast<int32_t>(std::clamp<int64_t>(
            static_cast<int64_t>(value),
            std::numeric_limits<int32_t>::min(),
            std::numeric_limits<int32_t>::max())));
      }
      return result;
    };

    std::vector<int32_t> axes = to_int32(AxesAttribute());
    std::vector<int32_t> starts = to_int32(StartsAttribute());
    std::vector<int32_t> ends = to_int32(EndsAttribute());

    JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($2, $2 + $1)) : [],
                                        "ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [],
                                        "axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}),
                               gsl::narrow_cast<int32_t>(starts.size()),
                               reinterpret_cast<int32_t>((starts.size() > 0) ? starts.data() : nullptr) >> 2,
                               gsl::narrow_cast<int32_t>(ends.size()),
                               reinterpret_cast<int32_t>((ends.size() > 0) ? ends.data() : nullptr) >> 2,
                               gsl::narrow_cast<int32_t>(axes.size()),
                               reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2);
  }
};
// Slice variant that constructs its base class with dynamic = false.
class Slice_1 final : public Slice {
public:
// Forwards `info` to Slice and fixes the `dynamic` argument to false.
Slice_1(const OpKernelInfo& info) : Slice(info, false) {}
};
} // namespace js
} // namespace onnxruntime

View file

@ -51,7 +51,7 @@ class Split : public JsKernel, public SplitBase {
static_cast<int32_t>(axis_),
static_cast<int32_t>(num_outputs_),
gsl::narrow_cast<int32_t>(split_sizes.size()),
reinterpret_cast<int32_t>((!split_sizes.empty() > 0) ? split_sizes.data() : nullptr) >> 2);
reinterpret_cast<int32_t>((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
}
};