onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
Yulong Wang 14a8315f10
[js/web] [webgpu] new indices helper (#16957)
### Description
This PR introduces the new indices helper.

IndicesHelper is a helper class for generating WGSL code for
manipulating indices and data for a shader's input or output.

This class is designed to offer a unified way to generate WGSL code for
manipulating indices and data for a shader's input or output. The
following is a list of terminologies used in this class:
- `offset`: a uint32 value representing the offset of an element in the
data buffer.
- `indices`: an abstraction of a multi-dimensional array's indices
representing the data's index on each dimension.
- `value`: a value of a data element.

Users are expected to create an instance of this class for each shader's
input or output, and use the instance to generate WGSL code for
manipulating indices and data. Users create an indices helper instance by
calling one of the following 2 exported functions:
 - `inputVariable()`: create an indices helper instance for an input.
 - `outputVariable()`: create an indices helper instance for an output.


An indices helper instance contains helper functions for the following
operations:
- access readonly basic information, including: `name`(the name of the
input or output), `usage`(whether it's an input or an output) and
`shape`(the passed in shape).
- `type`: access readonly type information, including: `indices`(the
type of indices), `value`(the type of value at runtime), `storage`(the
type of value at storage) and `tensor`(the tensor type as represented in
TensorView).
- generate WGSL code for converting between offsets and indices. Use
`offsetToIndices()` to get a WGSL snippet that calculates indices from an
offset, and use `indicesToOffset()` to get a WGSL snippet that calculates
an offset from indices.
- to manipulate an instance of indices, use `setIndices()` and
`getIndices()` to set and get the indices on an indices variable.
- to manipulate data, use `set()`/`get()` to access data at the given
indices from parameter list, use `setByIndices()`/`getByIndices()` to
access data at the given indices from an indices variable, and use
`setByOffset()`/`getByOffset()` to access data at the given offset.
- `impl`: get WGSL code of function implementation for the util
functions mentioned above.

This change applies the new IndicesHelper throughout the code, though not
yet to all of it.
2023-08-11 11:36:59 -07:00

201 lines
8.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata, TensorInfo} from '../types';
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
/**
 * Attributes of the Slice operator.
 *
 * `starts`, `ends` and `axes` are parallel arrays: entry `i` of `starts`/`ends` applies to dimension `axes[i]`.
 * An empty `axes` array means that no axes were specified.
 */
export interface SliceAttributes extends AttributeWithCacheKey {
  /** Dimensions the slice applies to; may be empty. */
  readonly axes: number[];
  /** Start index (inclusive) per sliced axis; negative values count from the end of the dimension. */
  readonly starts: number[];
  /** End index (exclusive) per sliced axis; negative values count from the end of the dimension. */
  readonly ends: number[];
}
/**
 * Validates the Slice kernel inputs and attributes.
 *
 * @param inputs kernel inputs; `inputs[0]` is the data tensor and every further input (starts, ends, axes, steps)
 *     must be an int32 or int64 tensor.
 * @param attributes node attributes; when `axes` is non-empty, `axes`, `starts` and `ends` must have the same
 *     length, otherwise `starts` and `ends` must.
 * @throws Error when any of these constraints is violated.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: SliceAttributes): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  if (attributes.axes.length !== 0) {
    if (attributes.axes.length !== attributes.starts.length || attributes.axes.length !== attributes.ends.length) {
      throw new Error('axes, starts and ends must have the same length');
    }
  } else if (attributes.starts.length !== attributes.ends.length) {
    throw new Error('starts and ends must have the same length');
  }
  // Iterate over the real input positions so the error names the offending input by its index in `inputs`
  // (the same numbering readInput uses in its message).
  for (let idx = 1; idx < inputs.length; ++idx) {
    const dataType = inputs[idx].dataType;
    if (dataType !== DataType.int32 && dataType !== DataType.int64) {
      throw new Error(`Input ${idx} must be an array of int32 or int64`);
    }
  }
};
/**
 * Reads an optional integer tensor input into a plain array of numbers.
 *
 * @param inputs kernel inputs.
 * @param idx position of the input to read.
 * @returns the elements converted with `Number` (int64 values may lose precision beyond 2^53), or an empty array
 *     when `inputs` has no entry at `idx`.
 * @throws Error when the input exists but is neither int32 nor int64.
 */
const readInput = (inputs: readonly TensorView[], idx: number): number[] => {
  if (idx >= inputs.length) {
    return [];
  }
  const tensor = inputs[idx];
  switch (tensor.dataType) {
    case DataType.int64:
      return Array.from(tensor.getBigInt64Array(), value => Number(value));
    case DataType.int32:
      return Array.from(tensor.getInt32Array(), value => Number(value));
    default:
      throw new Error(`Input ${idx} must be an array of int32 or int64`);
  }
};
/**
 * Resolves the effective Slice attributes.
 *
 * With more than one kernel input, starts, ends and axes are read from inputs 1, 2 and 3 and a fresh attribute
 * object is built; a missing or empty axes input defaults to every dimension of the data tensor. With only the
 * data input, the node attributes are returned unchanged.
 */
const createSliceAttributesFromInputs =
    (inputs: readonly TensorView[], attributes: SliceAttributes): SliceAttributes => {
      if (inputs.length <= 1) {
        return attributes;
      }
      const starts = readInput(inputs, 1);
      const ends = readInput(inputs, 2);
      const providedAxes = readInput(inputs, 3);
      const axes = providedAxes.length > 0 ? providedAxes :
                                             Array.from({length: inputs[0].dims.length}, (_, dim) => dim);
      return createAttributeWithCacheKey({starts, ends, axes});
    };
/**
 * Normalizes a start or end value of the `index`-th sliced axis.
 *
 * With `dim = inputShape[axes[index]]`, a negative value is first shifted by `dim`. The result is then clamped to
 * `[0, dim]` when the axis step is positive and to `[0, dim - 1]` when it is negative.
 *
 * NOTE(review): the ONNX Slice spec clamps *ends* to `[-1, dim - 1]` for negative steps, so that a reverse slice can
 * include element 0. This helper clamps starts and ends identically and therefore cannot express that case.
 */
const fixStartEndValues =
    (value: number, index: number, inputShape: readonly number[], axes: readonly number[], steps: readonly number[]):
        number => {
          const dim = inputShape[axes[index]];
          const shifted = value < 0 ? value + dim : value;
          const upperBound = steps[index] < 0 ? dim - 1 : dim;
          return Math.max(0, Math.min(shifted, upperBound));
        };
/**
 * Generates the WGSL function `calculateInputIndices`, which maps output indices to input indices.
 *
 * The generated code reads the module-scope constant arrays `signs`, `starts`, `ends` and `steps` (one entry per
 * dimension) that the embedding shader declares. It iterates over dimensions `rank - 1` down to `0`.
 * - Positive sign: the input index is `starts[i] + outputIndex * steps[i]`.
 * - Negative sign: the host side has swapped start and end and made the step positive, so `ends[i]` holds the
 *   original start. Elements are then read backwards as `ends[i] - outputIndex * steps[i]`.
 *
 * Rank-1 indices are scalars rather than vectors, which is why they are not subscripted.
 */
const calculateInputIndicesImpl =
    (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
        string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
    var inputIndices: ${input.type.indices};
    for (var i = ${inputShape.length - 1}; i >= 0; i--) {
      let outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
      var inputIndex: u32;
      if (signs[i] < 0) {
        inputIndex = ends[i] - outputIndex * steps[i];
      } else {
        inputIndex = starts[i] + outputIndex * steps[i];
      }
      ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
    }
    return inputIndices;
}`;
/**
 * Builds the WebGPU program that slices `inputs[0]`.
 *
 * The attribute arrays (`starts[i]`/`ends[i]`/`steps[i]` apply to dimension `axes[i]`) are first expanded into
 * per-dimension arrays. This lets the shader index them by dimension regardless of the order of `axes`. Dimensions
 * not listed in `axes` are copied whole (start 0, end dim, step 1). Negative steps are then made positive by
 * swapping start and end, and the original direction is kept in `signs` for the shader.
 *
 * @param metadata program metadata (name, input types, cache hint).
 * @param inputs kernel inputs; `inputs[0]` is the data tensor and `inputs[4]`, when present, holds the steps.
 * @param attributes resolved starts/ends/axes. An empty `axes` means `[0, ..., starts.length - 1]`.
 * @returns program info whose single output has the sliced shape.
 * @throws Error when a step is 0, or when starts, ends and steps do not all have one entry per axis.
 */
const createSliceProgramInfo =
    (metadata: ProgramMetadata, inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => {
      const inputShape = inputs[0].dims;
      const rank = inputShape.length;
      // Omitted axes apply starts/ends to the leading dimensions, one per entry.
      const axes = (attributes.axes.length > 0) ? ShapeUtil.normalizeAxes(attributes.axes, rank) :
                                                  [...Array(attributes.starts.length).keys()];
      let steps = readInput(inputs, 4);
      if (steps.some(step => step === 0)) {
        throw new Error('step cannot be 0');
      }
      if (steps.length === 0) {
        steps = Array(axes.length).fill(1);
      }
      if (attributes.starts.length !== axes.length || attributes.ends.length !== axes.length ||
          steps.length !== axes.length) {
        throw new Error('starts, ends and steps must have one entry per axis');
      }

      // Per-dimension parameters, indexed by dimension rather than by position in `axes`.
      const starts: number[] = Array(rank).fill(0);
      const ends: number[] = [...inputShape];
      const dimSteps: number[] = Array(rank).fill(1);
      axes.forEach((axis, i) => {
        starts[axis] = fixStartEndValues(attributes.starts[i], i, inputShape, axes, steps);
        ends[axis] = fixStartEndValues(attributes.ends[i], i, inputShape, axes, steps);
        dimSteps[axis] = steps[i];
      });

      const signs = dimSteps.map(step => Math.sign(step));
      // For a negative step only the element count is taken from the swapped range [end, start) walked with |step|.
      // The shader uses `signs` and `ends` (now the original start) to read those elements backwards.
      for (let dim = 0; dim < rank; ++dim) {
        if (dimSteps[dim] < 0) {
          [starts[dim], ends[dim]] = [ends[dim], starts[dim]];
          dimSteps[dim] = -dimSteps[dim];
        }
      }

      // Clamp at 0 so an empty range (start beyond end) yields a 0-sized dimension instead of a negative one.
      const outputShape = starts.map((start, dim) => Math.max(0, Math.ceil((ends[dim] - start) / dimSteps[dim])));
      const outputTensorInfo:
          TensorInfo = {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default};
      const output = outputVariable('output', inputs[0].dataType, outputShape);
      const input = inputVariable('input', inputs[0].dataType, inputShape);
      const outputSize = ShapeUtil.size(outputShape);
      const toU32List = (values: readonly number[]) => values.map(v => `${v}u`).join(',');
      const getShaderSource = (shaderHelper: ShaderHelper) => `
      ${shaderHelper.declareVariables(input, output)}
        const signs = array<i32, ${rank}>(${signs.map(s => `${s}i`).join(',')});
        const starts = array<u32, ${rank}>(${toU32List(starts)});
        const ends = array<u32, ${rank}>(${toU32List(ends)});
        const steps = array<u32, ${rank}>(${toU32List(dimSteps)});
        const inputShape = array<u32, ${rank}>(${toU32List(inputShape)});

        ${output.impl('offsetToIndices')}
        ${input.impl('indicesToOffset', 'get')}
        ${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
          let outputIndices = ${output.offsetToIndices('global_idx')};
          let inputIndices = calculateInputIndices(outputIndices);
          ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
      }`;
      return {
        ...metadata,
        getShaderSource,
        outputs: [outputTensorInfo],
        // One invocation per output element; the shader guards the tail of the last workgroup.
        dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
      };
    };
/**
 * Creates the program-info loader for Slice.
 *
 * The cache hint must distinguish every value baked into the generated shader as a constant. Starts/ends/axes are
 * presumably covered by the attributes' cache key (built by `createAttributeWithCacheKey` — verify). The steps come
 * from input 4 and are baked in as well, so their values, not just the tensor's dims, are appended to the hint.
 *
 * NOTE(review): the data tensor's shape is also baked into the shader; this assumes the program cache key already
 * accounts for the input dims — confirm against the program manager.
 */
const createSliceProgramInfoLoader =
    (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfoLoader => {
      const updatedAttributes = createSliceAttributesFromInputs(inputs, attributes);
      const stepsHint = inputs.length > 4 ? 'steps_' + readInput(inputs, 4).toString() : '';
      const metadata: ProgramMetadata = {
        name: 'Slice',
        inputTypes: [GpuDataType.default],
        cacheHint: updatedAttributes.cacheKey + stepsHint
      };
      return {...metadata, get: () => createSliceProgramInfo(metadata, inputs, updatedAttributes)};
    };
/**
 * Slice kernel entry point.
 *
 * Validates the inputs, builds the program and runs it with only the data tensor (`inputs: [0]`) as a GPU input.
 * Starts, ends, axes and steps are read in JavaScript while the program is generated.
 *
 * @throws Error when the computed output has no elements (empty outputs are not supported yet).
 */
export const slice = (context: ComputeContext, attributes: SliceAttributes): void => {
  validateInputs(context.inputs, attributes);
  const programInfoLoader = createSliceProgramInfoLoader(context.inputs, attributes);
  const outputSize = ShapeUtil.size(programInfoLoader.get().outputs[0].dims);
  // Written as a negated comparison so that a NaN size is rejected too.
  if (!(outputSize > 0)) {
    // TODO: support empty output
    throw new Error('slice: output size is 0');
  }
  context.compute(programInfoLoader, {inputs: [0]});
};
/**
 * Extracts the Slice attributes (`starts`, `ends`, `axes`) from the raw attribute record.
 *
 * The values are taken as-is without runtime validation. The result is wrapped by `createAttributeWithCacheKey`.
 */
export const parseSliceAttributes = (attributes: Record<string, unknown>): SliceAttributes => {
  const {starts, ends, axes} = attributes as {starts: number[]; ends: number[]; axes: number[]};
  return createAttributeWithCacheKey({starts, ends, axes});
};