diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 65e54414e9..c40229cde9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -875,11 +875,12 @@ class ShaderHelperImpl implements ShaderHelper { @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch - ? 'let global_idx = global_id.x; let local_idx = local_id.x;' - : `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + - workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${ - workgroupSizeX * workgroupSizeY * workgroupSizeZ - }u + local_idx;`; + ? `let global_idx = global_id.x; + let local_idx = local_id.x; + let workgroup_index = workgroup_id.x;` + : `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + + workgroup_id.y * num_workgroups[0] + workgroup_id.x; + let global_idx = workgroup_index * ${workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`; return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ}) fn main(${paramList}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 3c08580128..ee877f8f0c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -36,33 +36,62 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; +const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => { + const newShape: number[] = []; + const newPerm: number[] = []; + for (let i = 0; i < shape.length; ++i) { + if (shape[i] !== 1) { + newShape.push(shape[i]); + } + if (shape[adjustedPerm[i]] !== 1) { + newPerm.push(adjustedPerm[i]); + } + } + return { newShape, newPerm }; +}; + export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { const inputDataType = inputTensor.dataType; const inputRank = inputTensor.dims.length; const perm = getAdjustedPerm(inputRank, permAttr); const outputShape = getOutputShape(inputTensor.dims, perm); - const output = outputVariable('output', inputDataType, outputShape.length); - const input = inputVariable('a', inputDataType, inputRank); + const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm); + const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]); + const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]); + const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst; + let newInputShape = useShared ? newShape : inputTensor.dims; + let newOutputShape = outputShape; + if (useShared) { + newInputShape = channelsLast + ? [newShape[0], newShape[1] * newShape[2]] + : channelsFirst + ? [newShape[0] * newShape[1], newShape[2]] + : newShape; + newOutputShape = [newInputShape[1], newInputShape[0]]; + } + const input = inputVariable('a', inputDataType, newInputShape.length); + const output = outputVariable('output', inputDataType, newOutputShape.length); + const tileSize = 16; let getShaderSource; - if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) { - const wgslType = output.type.value; - const workgroupSize: [number, number, number] = [16, 16, 1]; + if (useShared) { getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} - var tile : array, ${workgroupSize[0]}>; - ${shaderHelper.mainStart(workgroupSize)} - var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x; - var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y; - let width = uniforms.output_shape[0]; - let height = uniforms.output_shape[1]; - if (x < width && y < height) { - tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')}; + var tile : array, ${tileSize}>; + ${shaderHelper.mainStart([tileSize, tileSize, 1])} + let stride = (uniforms.output_shape[1] - 1) / ${tileSize} + 1; + let workgroup_id_x = workgroup_index % stride; + let workgroup_id_y = workgroup_index / stride; + let input_col = workgroup_id_y * ${tileSize}u + local_id.x; + let input_row = workgroup_id_x * ${tileSize}u + local_id.y; + if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) { + tile[local_id.y][local_id.x] = ${input.getByIndices(`${input.type.indices}(input_row, input_col)`)}; } workgroupBarrier(); - x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x; - y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y; - if (x < height && y < width) { - ${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')} + + let output_col = workgroup_id_x * ${tileSize}u + local_id.x; + let output_row = workgroup_id_y * ${tileSize}u + local_id.y; + if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) { + ${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')} } }`; } else { @@ -81,16 +110,18 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu }`; } return { - name: 'Transpose', + name: useShared ? 'TransposeShared' : 'Transpose', shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] }, getRunData: () => { const outputSize = ShapeUtil.size(outputShape); return { outputs: [{ dims: outputShape, dataType: inputTensor.dataType }], - dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + dispatchGroup: useShared + ? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) } + : { x: Math.ceil(outputSize / 64 /* workgroup size */) }, programUniforms: [ { type: DataType.uint32, data: outputSize }, - ...createTensorShapeVariables(inputTensor.dims, outputShape), + ...createTensorShapeVariables(newInputShape, newOutputShape), ], }; }, diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 2b01475522..a7265d6444 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -167,6 +167,78 @@ } ] }, + { + "name": "Transpose squeezed 2d - perms:[0, 2, 1, 3]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 2, 1, 3], "type": "ints" }], + "cases": [ + { + "name": "T[1, 3 , 4, 1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 3, 4, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12], + "dims": [1, 4, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Transpose 4D channelsFirst - perms:[0, 3, 1, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 3, 1, 2], "type": "ints" }], + "cases": [ + { + "name": "T[1, 2, 3, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [1, 4, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Transpose 4D channelsLast - perms:[0, 2, 3, 1]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [0, 2, 3, 1], "type": "ints" }], + "cases": [ + { + "name": "T[1, 2, 3, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [1, 2, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24], + "dims": [1, 3, 4, 2], + "type": "float32" + } + ] + } + ] + }, { "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", "operator": "Transpose",