[JS/Web]Added uniforms support to Slice op. (#18422)

### Description
Support uniforms in Slice op



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Improve ferformance
This commit is contained in:
satyajandhyala 2023-11-16 09:44:13 -08:00 committed by GitHub
parent 999752a35d
commit b291b20fa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 91 additions and 21 deletions

View file

@ -646,6 +646,8 @@ export const outputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, false, components);
export type UniformsArrayType = Array<{name: string; type: string}>;
/**
* A ShaderHelper is a helper class for generating WGSL code.
*/
@ -697,6 +699,7 @@ export interface ShaderHelper {
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
*/
registerUniform(name: string, type: string): ShaderHelper;
registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper;
}
class ShaderHelperImpl implements ShaderHelper {
@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}
registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper {
this.uniforms = this.uniforms.concat(additionalUniforms);
return this;
}
private indicesHelpers: IndicesHelper[] = [];
private uniforms: Array<{name: string; type: string}> = [];
private uniforms: UniformsArrayType = [];
private uniformDeclaration(): string {
if (this.uniforms.length === 0) {
return '';

View file

@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, TensorInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
export interface SliceAttributes extends AttributeWithCacheKey {
readonly starts: number[];
@ -77,17 +77,26 @@ const fixStartEndValues =
};
const calculateInputIndicesImpl =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
enableInputShapeUniforms: boolean): string =>
`fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
var inputIndices: ${input.type.indices};
var carry = 0u;
for (var i = ${inputShape.length}; i >= 0; i--) {
let input_shape_i = ${
enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
let steps_i = ${
enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
let signs_i = ${
enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
let starts_i = ${
enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
var inputIndex = outputIndex * steps[i] + starts[i] + carry;
carry = inputIndex / inputShape[i];
inputIndex = inputIndex % inputShape[i];
if (signs[i] < 0) {
inputIndex = inputShape[i] - inputIndex - 1u + starts[i];
var inputIndex = outputIndex * steps_i + starts_i + carry;
carry = inputIndex / input_shape_i;
inputIndex = inputIndex % input_shape_i;
if (signs_i < 0) {
inputIndex = input_shape_i - inputIndex - 1u + starts_i;
}
${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
}
@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps));
if (axes.length !== starts.length || axes.length !== ends.length) {
throw new Error('start, ends and axes should have the same number of elements');
}
if (axes.length !== inputShape.length) {
for (let i = 0; i < inputShape.length; ++i) {
if (!axes.includes(i)) {
@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
array[i] = -step;
}
});
// Output rank is expected to be less than or equal to the input rank.
const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;
const outputShape = inputShape.slice(0);
axes.forEach((axis, _) => {
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
});
const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;
const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};
const output = outputVariable('output', inputs[0].dataType, outputShape);
const input = inputVariable('input', inputs[0].dataType, inputShape);
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
const outputSize = ShapeUtil.size(outputShape);
const programUniforms: ProgramUniform[] = [];
const uniforms: UniformsArrayType = [];
if (enableShapeUniforms) {
uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}<u32>` : 'u32'});
uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}<i32>` : 'i32'});
uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}<u32>` : 'u32'});
programUniforms.push({type: 'uint32', data: starts});
programUniforms.push({type: 'int32', data: signs});
programUniforms.push({type: 'uint32', data: steps});
}
uniforms.push({name: 'outputSize', type: 'u32'});
programUniforms.push({type: 'uint32', data: outputSize});
if (enableShapeUniforms) {
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
programUniforms.push(...createTensorShapeVariables(outputShape));
}
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});
const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});
const ends = array<u32, ${ends.length}>(${ends.map(i => `${i}u`).join(',')});
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${enableShapeUniforms ? '' : [
`const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});`,
`const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});`,
`const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});`,
`const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});`
].join('\n')}
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
let inputIndices = calculateInputIndices(outputIndices);
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
}`;
return {
name: 'Slice',
shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`},
shaderCache: {
hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
`${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
},
getShaderSource,
getRunData: () => ({
outputs: [outputTensorInfo],
dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
programUniforms
})
};
};

View file

@ -21,6 +21,29 @@
}
]
},
{
"name": "Slice float32 with input[0] dim > 4",
"operator": "Slice",
"attributes": [],
"cases": [
{
"name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)",
"inputs": [
{
"data": [
0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692
],
"dims": [1, 1, 1, 1, 5],
"type": "float32"
},
{ "data": [3], "dims": [1], "type": "int64" },
{ "data": [4], "dims": [1], "type": "int64" },
{ "data": [4], "dims": [1], "type": "int64" }
],
"outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }]
}
]
},
{
"name": "Slice int32",
"operator": "Slice",