mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[js/webgpu] Fix the transpose error when dims > 4D (#18027)
### Description <!-- Describe your changes. --> Currently, the uniform support has bugs when dims rank is larger than 4. See https://github.com/microsoft/onnxruntime/issues/17860 item 1. So this PR only enables shapes uniforms when shape rank is <= 4 for transpose. Otherwise, below compilation errors are thrown: ``` 1 error(s) generated while compiling the shader: :3:50 error: uniform storage requires that array elements are aligned to 16 bytes, but array element of type 'u32' has a stride of 4 bytes. Consider using a vector or struct as the element type instead. struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> }; ^^^^^^^^^^^^^ :3:7 note: see layout of struct: /* align(4) size(84) */ struct Uniforms { /* offset( 0) align(4) size( 4) */ output_size : u32; /* offset( 4) align(4) size(20) */ a_shape : array<u32, 5>; /* offset(24) align(4) size(20) */ a_strides : array<u32, 5>; /* offset(44) align(4) size(20) */ output_shape : array<u32, 5>; /* offset(64) align(4) size(20) */ output_strides : array<u32, 5>; /* */ }; struct Uniforms { output_size:u32, a_shape:array<u32, 5>, a_strides:array<u32, 5>, output_shape:array<u32, 5>, output_strides:array<u32, 5> }; ^^^^^^ :4:42 note: 'Uniforms' used in address space 'uniform' here @group(0) @binding(2) var<uniform> uniforms: Uniforms; ^^^^^^^^ ```
This commit is contained in:
parent
f0d5ea5930
commit
8a12b2cea6
5 changed files with 59 additions and 25 deletions
|
|
@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
|
|||
}
|
||||
return dims;
|
||||
};
|
||||
|
||||
// TODO: remove this limitation once >4D dims are supported by uniform.
|
||||
export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
|
||||
|
|
|
|||
|
|
@ -232,7 +232,7 @@ const convTranspose2d =
|
|||
// STEP.1: transpose weight
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm),
|
||||
createTransposeProgramInfo(inputs[1], weightTransposePerm),
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
if (isChannelsLast) {
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
|
||||
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
|
|
@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
|
|||
// STEP.1: transpose weight
|
||||
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
|
||||
context.compute(
|
||||
createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
|
||||
createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
|
||||
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
|
||||
if (attributes.wIsConst && !context.kernelCustomData.wT) {
|
||||
context.kernelCustomData.wT = transposedWeight;
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface TransposeAttributes extends AttributeWithCacheKey {
|
||||
readonly perm: number[];
|
||||
|
|
@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
|
|||
return reverseFunc.join('\n');
|
||||
};
|
||||
|
||||
export const createTransposeProgramInfo =
|
||||
(inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => {
|
||||
const perm = getAdjustedPerm(inputRank, permAttr);
|
||||
const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank);
|
||||
const input = inputVariable('a', inputDataType, inputRank);
|
||||
export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
|
||||
const inputDataType = inputTensor.dataType;
|
||||
const inputRank = inputTensor.dims.length;
|
||||
const perm = getAdjustedPerm(inputRank, permAttr);
|
||||
const useShapesUniforms = enableShapesUniforms(inputRank);
|
||||
const outputShape = getOutputShape(inputTensor.dims, perm);
|
||||
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
|
||||
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
|
||||
const output = outputVariable('output', inputDataType, outShapeOrRank);
|
||||
const input = inputVariable('a', inputDataType, inShapeOrRank);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
|
||||
${permFunctionBody(perm, inputRank, input, output)}
|
||||
|
|
@ -54,30 +59,32 @@ export const createTransposeProgramInfo =
|
|||
|
||||
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
|
||||
}`;
|
||||
return {
|
||||
name: 'Transpose',
|
||||
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
|
||||
getRunData: (inputs) => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
name: 'Transpose',
|
||||
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
|
||||
getRunData: (inputs) => {
|
||||
const outputShape = getOutputShape(inputs[0].dims, perm);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: useShapesUniforms ?
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
] :
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
],
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
context.compute(
|
||||
createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm));
|
||||
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
|
||||
};
|
||||
|
||||
export const parseTransposeAttributes = (attributes: Record<string, unknown>): TransposeAttributes =>
|
||||
|
|
|
|||
|
|
@ -166,5 +166,29 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
|
||||
"operator": "Transpose",
|
||||
"attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 1, 2, 1, 4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
|
||||
"dims": [3, 1, 2, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
|
||||
"dims": [4, 1, 1, 3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue