mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[JS/Web]Added uniforms support to Slice op. (#18422)
### Description Support uniforms in Slice op ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Improve ferformance
This commit is contained in:
parent
999752a35d
commit
b291b20fa0
3 changed files with 91 additions and 21 deletions
|
|
@ -646,6 +646,8 @@ export const outputVariable =
|
|||
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
|
||||
createIndicesHelper(name, type, shapeOrRank, false, components);
|
||||
|
||||
export type UniformsArrayType = Array<{name: string; type: string}>;
|
||||
|
||||
/**
|
||||
* A ShaderHelper is a helper class for generating WGSL code.
|
||||
*/
|
||||
|
|
@ -697,6 +699,7 @@ export interface ShaderHelper {
|
|||
* A helper function to register one uniform. Can be called multiple times to register multiple uniforms.
|
||||
*/
|
||||
registerUniform(name: string, type: string): ShaderHelper;
|
||||
registerUniforms(nameToTypeMap: UniformsArrayType): ShaderHelper;
|
||||
}
|
||||
|
||||
class ShaderHelperImpl implements ShaderHelper {
|
||||
|
|
@ -755,8 +758,13 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
return this;
|
||||
}
|
||||
|
||||
registerUniforms(additionalUniforms: UniformsArrayType): ShaderHelper {
|
||||
this.uniforms = this.uniforms.concat(additionalUniforms);
|
||||
return this;
|
||||
}
|
||||
|
||||
private indicesHelpers: IndicesHelper[] = [];
|
||||
private uniforms: Array<{name: string; type: string}> = [];
|
||||
private uniforms: UniformsArrayType = [];
|
||||
private uniformDeclaration(): string {
|
||||
if (this.uniforms.length === 0) {
|
||||
return '';
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common';
|
|||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, TensorInfo} from '../types';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
|
||||
export interface SliceAttributes extends AttributeWithCacheKey {
|
||||
readonly starts: number[];
|
||||
|
|
@ -77,17 +77,26 @@ const fixStartEndValues =
|
|||
};
|
||||
|
||||
const calculateInputIndicesImpl =
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[]):
|
||||
string => `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
|
||||
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
|
||||
enableInputShapeUniforms: boolean): string =>
|
||||
`fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
|
||||
var inputIndices: ${input.type.indices};
|
||||
var carry = 0u;
|
||||
for (var i = ${inputShape.length}; i >= 0; i--) {
|
||||
let input_shape_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
|
||||
let steps_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
|
||||
let signs_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
|
||||
let starts_i = ${
|
||||
enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
|
||||
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
|
||||
var inputIndex = outputIndex * steps[i] + starts[i] + carry;
|
||||
carry = inputIndex / inputShape[i];
|
||||
inputIndex = inputIndex % inputShape[i];
|
||||
if (signs[i] < 0) {
|
||||
inputIndex = inputShape[i] - inputIndex - 1u + starts[i];
|
||||
var inputIndex = outputIndex * steps_i + starts_i + carry;
|
||||
carry = inputIndex / input_shape_i;
|
||||
inputIndex = inputIndex % input_shape_i;
|
||||
if (signs_i < 0) {
|
||||
inputIndex = input_shape_i - inputIndex - 1u + starts_i;
|
||||
}
|
||||
${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
|
||||
}
|
||||
|
|
@ -110,6 +119,10 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
|
|||
|
||||
const ends = attributes.ends.map((end, i) => fixStartEndValues(end, i, inputShape, axes, steps));
|
||||
|
||||
if (axes.length !== starts.length || axes.length !== ends.length) {
|
||||
throw new Error('start, ends and axes should have the same number of elements');
|
||||
}
|
||||
|
||||
if (axes.length !== inputShape.length) {
|
||||
for (let i = 0; i < inputShape.length; ++i) {
|
||||
if (!axes.includes(i)) {
|
||||
|
|
@ -131,40 +144,66 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
|
|||
array[i] = -step;
|
||||
}
|
||||
});
|
||||
// Output rank is expected to be less than or equal to the input rank.
|
||||
const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
|
||||
const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;
|
||||
|
||||
const outputShape = inputShape.slice(0);
|
||||
axes.forEach((axis, _) => {
|
||||
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
|
||||
});
|
||||
const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;
|
||||
|
||||
const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape);
|
||||
const input = inputVariable('input', inputs[0].dataType, inputShape);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
|
||||
const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] = [];
|
||||
const uniforms: UniformsArrayType = [];
|
||||
if (enableShapeUniforms) {
|
||||
uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}<u32>` : 'u32'});
|
||||
uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}<i32>` : 'i32'});
|
||||
uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}<u32>` : 'u32'});
|
||||
programUniforms.push({type: 'uint32', data: starts});
|
||||
programUniforms.push({type: 'int32', data: signs});
|
||||
programUniforms.push({type: 'uint32', data: steps});
|
||||
}
|
||||
uniforms.push({name: 'outputSize', type: 'u32'});
|
||||
programUniforms.push({type: 'uint32', data: outputSize});
|
||||
if (enableShapeUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});
|
||||
const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});
|
||||
const ends = array<u32, ${ends.length}>(${ends.map(i => `${i}u`).join(',')});
|
||||
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
|
||||
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
|
||||
${enableShapeUniforms ? '' : [
|
||||
`const signs = array<i32, ${signs.length}>(${signs.map(i => `${i}i`).join(',')});`,
|
||||
`const starts = array<u32, ${starts.length}>(${starts.map(i => `${i}u`).join(',')});`,
|
||||
`const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});`,
|
||||
`const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});`
|
||||
].join('\n')}
|
||||
|
||||
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
|
||||
${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
let outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
let inputIndices = calculateInputIndices(outputIndices);
|
||||
${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
|
||||
}`;
|
||||
return {
|
||||
name: 'Slice',
|
||||
shaderCache: {hint: `${attributes.cacheKey}|${inputs[4]?.dims ?? ''}`},
|
||||
shaderCache: {
|
||||
hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
|
||||
`${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
|
||||
inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
|
||||
},
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [outputTensorInfo],
|
||||
dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
})
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -21,6 +21,29 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Slice float32 with input[0] dim > 4",
|
||||
"operator": "Slice",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1, 1, 1, 1, 5] T[1] T[1] T[1] (float32)",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3964604139328003, -0.8916832804679871, -1.6578896045684814, 1.960708737373352, 1.181204915046692
|
||||
],
|
||||
"dims": [1, 1, 1, 1, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{ "data": [3], "dims": [1], "type": "int64" },
|
||||
{ "data": [4], "dims": [1], "type": "int64" },
|
||||
{ "data": [4], "dims": [1], "type": "int64" }
|
||||
],
|
||||
"outputs": [{ "data": [1.960708737373352], "dims": [1, 1, 1, 1, 1], "type": "float32" }]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Slice int32",
|
||||
"operator": "Slice",
|
||||
|
|
|
|||
Loading…
Reference in a new issue