From 8dba6efd612ee647923670ec0919d8192d815cee Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Sat, 11 Nov 2023 05:46:03 +0800 Subject: [PATCH] [js/webgpu] Add uniforms support to concat op (#18238) --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 56 +++++++++++++++++------ 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 4b5ca869f0..43cc4a4c08 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -4,9 +4,9 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface ConcatAttributes extends AttributeWithCacheKey { readonly axis: number; @@ -33,9 +33,10 @@ const validateInputs = (inputs: readonly TensorView[]): void => { } }; -const calculateInputIndexImpl = (numberOfTensors: number): string => ` +const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => ` fn calculateInputIndex(index: u32) -> u32 { - for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); + for (var i: u32 = 0u; i < ${numberOfTensors}; i += 1u ) { if (index < sizeInConcatAxis[i]) { return i; } @@ -92,40 +93,69 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P const dataType = inputs[0].dataType; let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; + const inputShapeOrRanks = []; + const enableInputShapesUniforms = []; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; - - inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); + inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims); + inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); + inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); + programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + } + for (let i = 0; i < inputs.length; ++i) { + if (enableInputShapesUniforms[i]) { + programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); + } } - const output = outputVariable('output', dataType, outputShape); + const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); + if (enableOutputShapesUniforms) { + programUniforms.push(...createTensorShapeVariables(outputShape)); + } + + const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; + const output = outputVariable('output', dataType, outputShapeOrRank); const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(...inputVars, output)} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); - ${calculateInputIndexImpl(sizeInConcatAxis.length)} + ${(() => { + shaderHelper.registerUniform('outputSize', 'u32'); + for (let i = 0; i < inputs.length; i++) { + shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32'); + } + return shaderHelper.declareVariables(...inputVars, output); + })()} + + ${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var indices = ${output.offsetToIndices('global_idx')}; let inputIndex = calculateInputIndex(${indicesAxis}); if (inputIndex != 0u) { + let sizeInConcatAxis = array(${sizeInConcatAxisStr}); ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u]; } ${assignOutputData(inputVars, output)} }`; + return { name: 'Concat', - shaderCache: {hint: `${axis}`}, + shaderCache: {hint: `${axis}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms, }), getShaderSource, };