[js/webgpu] Optimize transpose (#21964)

### Description
<!-- Describe your changes. -->
Fix bugs in the previous implementation and add more situations that take the
optimized path.

The following situations now take the optimized path:
1. 2D inputs, or squeezed 2D inputs (inputs that are 2D once size-1 dimensions are removed).
2. Channels-last or channels-first transposes. For example, the channels-last
transpose [1, 256, 512, 512] -> [1, 512, 512, 256]
is handled as the 2D transpose [256, 512x512] -> [512x512, 256].

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
For the SD Turbo demo, the total transpose time drops from 122.09 ms to
39.98 ms, and the corresponding percentage drops from 11.05% to 3.89% in
this demo.

This PR also helps #21618: the total transpose time in that demo drops
from 70.25 ms to 17.32 ms on my iGPUs.
This commit is contained in:
Jiajia Qin 2024-09-05 03:04:04 +08:00 committed by GitHub
parent 190588bb64
commit a80bfed5b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 129 additions and 25 deletions

View file

@ -875,11 +875,12 @@ class ShaderHelperImpl implements ShaderHelper {
@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
const globalIdxDefinition = is1DimensionDispatch
? 'let global_idx = global_id.x; let local_idx = local_id.x;'
: `let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x) * ${
workgroupSizeX * workgroupSizeY * workgroupSizeZ
}u + local_idx;`;
? `let global_idx = global_id.x;
let local_idx = local_id.x;
let workgroup_index = workgroup_id.x;`
: `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] +
workgroup_id.y * num_workgroups[0] + workgroup_id.x;
let global_idx = workgroup_index * ${workgroupSizeX * workgroupSizeY * workgroupSizeZ}u + local_idx;`;
return `@compute @workgroup_size(${workgroupSizeX}, ${workgroupSizeY}, ${workgroupSizeZ})
fn main(${paramList}) {

View file

@ -36,33 +36,62 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
return reverseFunc.join('\n');
};
/**
 * Removes every size-1 dimension from `shape` and drops the entries of
 * `adjustedPerm` that refer to a size-1 dimension.
 *
 * The permutation entries that survive keep their ORIGINAL axis numbers; they
 * are not renumbered to the squeezed rank. For example, shape `[1, 3, 4, 1]`
 * with perm `[0, 2, 1, 3]` yields `newShape = [3, 4]` and `newPerm = [2, 1]`.
 *
 * @param shape - Dimensions of the input tensor.
 * @param adjustedPerm - Permutation indexed position-by-position alongside `shape`.
 * @returns The squeezed shape and the filtered (un-renumbered) permutation.
 */
const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => {
  const squeezedDims: number[] = [];
  const keptAxes: number[] = [];

  for (const [position, dim] of shape.entries()) {
    // Keep the dimension itself only when it is not a unit dimension.
    if (dim !== 1) {
      squeezedDims.push(dim);
    }

    // Keep the permutation entry only when the axis it points at is not a unit dimension.
    const sourceAxis = adjustedPerm[position];
    if (shape[sourceAxis] !== 1) {
      keptAxes.push(sourceAxis);
    }
  }

  return { newShape: squeezedDims, newPerm: keptAxes };
};
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
const inputDataType = inputTensor.dataType;
const inputRank = inputTensor.dims.length;
const perm = getAdjustedPerm(inputRank, permAttr);
const outputShape = getOutputShape(inputTensor.dims, perm);
const output = outputVariable('output', inputDataType, outputShape.length);
const input = inputVariable('a', inputDataType, inputRank);
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
let newInputShape = useShared ? newShape : inputTensor.dims;
let newOutputShape = outputShape;
if (useShared) {
newInputShape = channelsLast
? [newShape[0], newShape[1] * newShape[2]]
: channelsFirst
? [newShape[0] * newShape[1], newShape[2]]
: newShape;
newOutputShape = [newInputShape[1], newInputShape[0]];
}
const input = inputVariable('a', inputDataType, newInputShape.length);
const output = outputVariable('output', inputDataType, newOutputShape.length);
const tileSize = 16;
let getShaderSource;
if (perm.length === 2 && perm[0] === 1 && perm[1] === 0) {
const wgslType = output.type.value;
const workgroupSize: [number, number, number] = [16, 16, 1];
if (useShared) {
getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
var<workgroup> tile : array<array<${wgslType}, ${workgroupSize[0] + 1}>, ${workgroupSize[0]}>;
${shaderHelper.mainStart(workgroupSize)}
var x = workgroup_id.x * ${workgroupSize[0]}u + local_id.x;
var y = workgroup_id.y * ${workgroupSize[0]}u + local_id.y;
let width = uniforms.output_shape[0];
let height = uniforms.output_shape[1];
if (x < width && y < height) {
tile[local_id.y][local_id.x] = ${input.getByOffset('y * width + x')};
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
${shaderHelper.mainStart([tileSize, tileSize, 1])}
let stride = (uniforms.output_shape[1] - 1) / ${tileSize} + 1;
let workgroup_id_x = workgroup_index % stride;
let workgroup_id_y = workgroup_index / stride;
let input_col = workgroup_id_y * ${tileSize}u + local_id.x;
let input_row = workgroup_id_x * ${tileSize}u + local_id.y;
if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {
tile[local_id.y][local_id.x] = ${input.getByIndices(`${input.type.indices}(input_row, input_col)`)};
}
workgroupBarrier();
x = workgroup_id.y * ${workgroupSize[0]}u + local_id.x;
y = workgroup_id.x * ${workgroupSize[0]}u + local_id.y;
if (x < height && y < width) {
${output.setByOffset('y * height + x', 'tile[local_id.x][local_id.y]')}
let output_col = workgroup_id_x * ${tileSize}u + local_id.x;
let output_row = workgroup_id_y * ${tileSize}u + local_id.y;
if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
}
}`;
} else {
@ -81,16 +110,18 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
}`;
}
return {
name: 'Transpose',
name: useShared ? 'TransposeShared' : 'Transpose',
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
getRunData: () => {
const outputSize = ShapeUtil.size(outputShape);
return {
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
dispatchGroup: useShared
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms: [
{ type: DataType.uint32, data: outputSize },
...createTensorShapeVariables(inputTensor.dims, outputShape),
...createTensorShapeVariables(newInputShape, newOutputShape),
],
};
},

View file

@ -167,6 +167,78 @@
}
]
},
{
"name": "Transpose squeezed 2d - perms:[0, 2, 1, 3]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 2, 1, 3], "type": "ints" }],
"cases": [
{
"name": "T[1, 3 , 4, 1]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 3, 4, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12],
"dims": [1, 4, 3, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 4D channelsFirst - perms:[0, 3, 1, 2]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 3, 1, 2], "type": "ints" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
"dims": [1, 4, 2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 4D channelsLast - perms:[0, 2, 3, 1]",
"operator": "Transpose",
"attributes": [{ "name": "perm", "data": [0, 2, 3, 1], "type": "ints" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24],
"dims": [1, 3, 4, 2],
"type": "float32"
}
]
}
]
},
{
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
"operator": "Transpose",