mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] Optimize transpose as reshape when suitable (#22870)
BUG #22031
This commit is contained in:
parent
c4f3742bb4
commit
e597eaed4a
2 changed files with 102 additions and 17 deletions
|
|
@ -48,17 +48,61 @@ const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newSh
|
|||
return { newShape, newPerm };
|
||||
};
|
||||
|
||||
const isTransposeReshape = (perm: number[], shape: readonly number[]) => {
|
||||
// As long as the dims with values > 1 stay in the same order, it's a reshape.
|
||||
// Example: Shape=(1,1,1024,4096) -> perm=(2,0,3,1).
|
||||
let lastPermutedAxis = 0;
|
||||
for (let i = 0; i < perm.length; ++i) {
|
||||
if (shape[perm[i]] === 1) {
|
||||
continue;
|
||||
}
|
||||
if (perm[i] < lastPermutedAxis) {
|
||||
return false;
|
||||
}
|
||||
lastPermutedAxis = perm[i];
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
|
||||
const inputDataType = inputTensor.dataType;
|
||||
const inputRank = inputTensor.dims.length;
|
||||
const perm = getAdjustedPerm(inputRank, permAttr);
|
||||
const outputShape = getOutputShape(inputTensor.dims, perm);
|
||||
let newInputShape = inputTensor.dims;
|
||||
let newOutputShape = outputShape;
|
||||
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
|
||||
let getShaderSource;
|
||||
if (transposeAsReshape) {
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const input = inputVariable('input', inputDataType, newInputShape, 4);
|
||||
const output = outputVariable('output', inputDataType, newOutputShape, 4);
|
||||
return `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
output[global_idx] = input[global_idx];
|
||||
}`;
|
||||
};
|
||||
|
||||
return {
|
||||
name: 'TransposeCopy',
|
||||
shaderCache: { inputDependencies: ['type'] },
|
||||
getRunData: () => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
|
||||
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* components */) },
|
||||
programUniforms: [{ type: DataType.uint32, data: Math.ceil(outputSize / 4) }],
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
};
|
||||
}
|
||||
const { newShape, newPerm } = squeezeShape(inputTensor.dims, perm);
|
||||
const channelsLast = ShapeUtil.areEqual(newPerm, [2, 3, 1]);
|
||||
const channelsFirst = ShapeUtil.areEqual(newPerm, [3, 1, 2]);
|
||||
const useShared = (newShape.length === 2 && newPerm[0] > newPerm[1]) || channelsLast || channelsFirst;
|
||||
let newInputShape = useShared ? newShape : inputTensor.dims;
|
||||
let newOutputShape = outputShape;
|
||||
const useShared = newShape.length === 2 || channelsLast || channelsFirst;
|
||||
if (useShared) {
|
||||
newInputShape = channelsLast
|
||||
? [newShape[0], newShape[1] * newShape[2]]
|
||||
|
|
@ -66,13 +110,11 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
? [newShape[0] * newShape[1], newShape[2]]
|
||||
: newShape;
|
||||
newOutputShape = [newInputShape[1], newInputShape[0]];
|
||||
}
|
||||
const input = inputVariable('a', inputDataType, newInputShape.length);
|
||||
const output = outputVariable('output', inputDataType, newOutputShape.length);
|
||||
const tileSize = 16;
|
||||
let getShaderSource;
|
||||
if (useShared) {
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const tileSize = 16;
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const input = inputVariable('a', inputDataType, newInputShape.length);
|
||||
const output = outputVariable('output', inputDataType, newOutputShape.length);
|
||||
return `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
var<workgroup> tile : array<array<${output.type.value}, ${tileSize + 1}>, ${tileSize}>;
|
||||
${shaderHelper.mainStart([tileSize, tileSize, 1])}
|
||||
|
|
@ -92,8 +134,29 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
${output.setByIndices(`${output.type.indices}(output_row, output_col)`, 'tile[local_id.x][local_id.y]')}
|
||||
}
|
||||
}`;
|
||||
} else {
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
};
|
||||
return {
|
||||
name: 'TransposeShared',
|
||||
shaderCache: { inputDependencies: ['type'] },
|
||||
getRunData: () => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
|
||||
dispatchGroup: { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) },
|
||||
programUniforms: [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
...createTensorShapeVariables(newInputShape, newOutputShape),
|
||||
],
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const input = inputVariable('a', inputDataType, newInputShape.length);
|
||||
const output = outputVariable('output', inputDataType, newOutputShape.length);
|
||||
return `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
|
||||
${permFunctionBody(perm, inputRank, input, output)}
|
||||
|
|
@ -106,17 +169,15 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
|
||||
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
|
||||
}`;
|
||||
}
|
||||
};
|
||||
return {
|
||||
name: useShared ? 'TransposeShared' : 'Transpose',
|
||||
name: 'Transpose',
|
||||
shaderCache: { hint: `${permAttr}`, inputDependencies: ['rank'] },
|
||||
getRunData: () => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{ dims: outputShape, dataType: inputTensor.dataType }],
|
||||
dispatchGroup: useShared
|
||||
? { x: Math.ceil(newOutputShape[1] / tileSize), y: Math.ceil(newOutputShape[0] / tileSize) }
|
||||
: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
||||
programUniforms: [
|
||||
{ type: DataType.uint32, data: outputSize },
|
||||
...createTensorShapeVariables(newInputShape, newOutputShape),
|
||||
|
|
|
|||
|
|
@ -263,6 +263,30 @@
|
|||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose as reshape - perms:[1, 0, 2, 4, 3]",
|
||||
"operator": "Transpose",
|
||||
"attributes": [{ "name": "perm", "data": [1, 0, 2, 4, 3], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 1, 2, 1, 4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
|
||||
"dims": [3, 1, 2, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
|
||||
"dims": [1, 3, 2, 4, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose - perms:[1, 0]",
|
||||
"operator": "Transpose",
|
||||
|
|
|
|||
Loading…
Reference in a new issue