From a80bfed5b428756df59397851e351a2b9c2efb6f Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 5 Sep 2024 03:04:04 +0800 Subject: [PATCH] [js/webgpu] Optimize transpose (#21964) ### Description Fix bugs in previous implementation and add more situations to go the optimized path. Below situations will go to the optimized path. 1. 2d inputs or squeezed 2d inputs 2. channels last or channels first transpose. For example, channel last transpose: [1, 256, 512, 512] -> [1, 512, 512, 256] For this case, the transpose becomes [256, 512x512] -> [512x512, 256] ### Motivation and Context For SD Turbo demo, the total transpose time becomes 39.98ms from 122.09ms. And the correspnding percents becomes 3.89% from 11.05% in this demo. This PR will also help #21618, the total transpose time in that demo becomes 17.32 ms from 70.25 ms on my iGPUs. --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 11 +-- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 71 +++++++++++++------ js/web/test/data/ops/transpose.jsonc | 72 ++++++++++++++++++++ 3 files changed, 129 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 65e54414e9..c40229cde9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -875,11 +875,12 @@ class ShaderHelperImpl implements ShaderHelper { @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch - ? 'let global_idx = global_id.x; let local_idx = local_id.x;' - : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + - workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ - }u + local_idx;`; + ? `let global_idx = global_id.x; + let local_idx = local_id.x; + let workgroup_index = workgroup_id.x;` + : `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + + workgroup_id.y * num_workgroups[0] + workgroup_id.x; + let global_idx = workgroup_index * ${workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 3c08580128..ee877f8f0c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -36,33 +36,62 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; +const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => { + const newShape: number[] = []; + const newPerm: number[] = []; + for (let i = 0; i < shape.length; ++i) { + if (shape[i] !== 1) { + newShape.push(shape[i]); + } + if (shape[adjustedPerm[i]] !== 1) { + newPerm.push(adjustedPerm[i]); + } + } + return { newShape, newPerm }; +}; + export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); const outputShape = getOutputShape(inputTensor.dims, perm); - const output = outputVariable('output', inputDataType, outputShape.length); - const input = inputVariable('a', inputDataType, inputRank); + const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm); + const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]); + const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]); + const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst; + let newInputShape = useShared ? newShape : inputTensor.dims; + let newOutputShape = outputShape; + if (useShared) { + newInputShape = channelsLast + ? [newShape[0], newShape[1] * newShape[2]] + : channelsFirst + ? [newShape[0] * newShape[1], newShape[2]] + : newShape; + newOutputShape = [newInputShape[1], newInputShape[0]]; + } + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + const tileSize = 16; let getShaderSource; - if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) { - const wgslType = output.type.value; - const workgroupSize: [number, number, number] = [16, 16, 1]; + if (useShared) { getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} - var tile : array, ${workgroupSize[0]}>; - ${shaderHelper.mainStart(workgroupSize)} - var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x; - var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y; - let width = uniforms.output_shape[0]; - let height = uniforms.output_shape[1]; - if (x < width && y < height) { - tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')}; + var tile : array, ${tileSize}>; + ${shaderHelper.mainStart([tileSize, tileSize, 1])} + let stride = (uniforms.output_shape[1] - 1) / ${tileSize} + 1; + let workgroup_id_x = workgroup_index % stride; + let workgroup_id_y = workgroup_index / stride; + let input_col = workgroup_id_y * ${tileSize}u + local_id.x; + let input_row = workgroup_id_x * ${tileSize}u + local_id.y; + if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) { + tile[local_id.y][local_id.x] = ${input.getByIndices(`${input.type.indices}(input_row, input_col)`)}; } workgroupBarrier(); - x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x; - y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y; - if (x < height && y < width) { - ${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')} + + let output_col = workgroup_id_x * ${tileSize}u + local_id.x; + let output_row = workgroup_id_y * ${tileSize}u + local_id.y; + if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) { + ${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')} } }`; } else { @@ -81,16 +110,18 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; } return { - name: 'Transpose', + name: useShared ? 'TransposeShared' : 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], - dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: useShared + ? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) } + : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, - ...createTensorShapeVariables(inputTensor.dims, outputShape), + ...createTensorShapeVariables(newInputShape, newOutputShape), ], }; }, diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 2b01475522..a7265d6444 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -167,6 +167,78 @@ } ] }, + { + "name": "Transpose squeezed 2d - perms:[0, 2, 1, 3]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 2, 1, 3], "type": "ints" }], + "cases": [ + { + "name": "T[1, 3 , 4, 1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 3, 4, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12], + "dims": [1, 4, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Transpose 4D channelsFirst - perms:[0, 3, 1, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 3, 1, 2], "type": "ints" }], + "cases": [ + { + "name": "T[1, 2, 3, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [1, 4, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Transpose 4D channelsLast - perms:[0, 2, 3, 1]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 2, 3, 1], "type": "ints" }], + "cases": [ + { + "name": "T[1, 2, 3, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24], + "dims": [1, 3, 4, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", "operator": "Transpose",