mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[js/webgpu] Add uniforms support to concat op (#18238)
This commit is contained in:
parent
64c91d790b
commit
8dba6efd61
1 changed files with 43 additions and 13 deletions
|
|
@ -4,9 +4,9 @@
|
|||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface ConcatAttributes extends AttributeWithCacheKey {
|
||||
readonly axis: number;
|
||||
|
|
@ -33,9 +33,10 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
}
|
||||
};
|
||||
|
||||
const calculateInputIndexImpl = (numberOfTensors: number): string => `
|
||||
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
|
||||
fn calculateInputIndex(index: u32) -> u32 {
|
||||
for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
|
||||
let sizeInConcatAxis = array<u32, ${numberOfTensors}u>(${sizeInConcatAxisStr});
|
||||
for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) {
|
||||
if (index < sizeInConcatAxis[i]) {
|
||||
return i;
|
||||
}
|
||||
|
|
@ -92,40 +93,69 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
|
|||
const dataType = inputs[0].dataType;
|
||||
|
||||
let previousSum = 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
const inputShapeOrRanks = [];
|
||||
const enableInputShapesUniforms = [];
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
previousSum += inputs[i].dims[adjustedAxis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
|
||||
enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
|
||||
inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
|
||||
inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
|
||||
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
|
||||
}
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
if (enableInputShapesUniforms[i]) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
}
|
||||
}
|
||||
|
||||
const output = outputVariable('output', dataType, outputShape);
|
||||
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
|
||||
if (enableOutputShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
|
||||
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
|
||||
const output = outputVariable('output', dataType, outputShapeOrRank);
|
||||
|
||||
const indicesAxis = output.indicesGet('indices', adjustedAxis);
|
||||
const sizeInConcatAxisStr =
|
||||
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(...inputVars, output)}
|
||||
|
||||
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
|
||||
${calculateInputIndexImpl(sizeInConcatAxis.length)}
|
||||
${(() => {
|
||||
shaderHelper.registerUniform('outputSize', 'u32');
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
|
||||
}
|
||||
return shaderHelper.declareVariables(...inputVars, output);
|
||||
})()}
|
||||
|
||||
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
|
||||
var indices = ${output.offsetToIndices('global_idx')};
|
||||
|
||||
let inputIndex = calculateInputIndex(${indicesAxis});
|
||||
if (inputIndex != 0u) {
|
||||
let sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}u>(${sizeInConcatAxisStr});
|
||||
${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
|
||||
}
|
||||
|
||||
${assignOutputData(inputVars, output)}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'Concat',
|
||||
shaderCache: {hint: `${axis}`},
|
||||
shaderCache: {hint: `${axis}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue