[js/webgpu] fix a bug in transpose shader (#22997)

### Description

Fix a bug in transpose shader, when input/output rank is 1.

### Motivation and Context

Fixes #22994
This commit is contained in:
Yulong Wang 2024-12-03 20:21:08 -08:00 committed by GitHub
parent e84b8e7bd5
commit 06526af346
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} {
var a: ${input.type.indices};`;
for (let i = 0; i < rank; ++i) {
reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`);
// input indices and output indices should always be larger or equal to 2,
// so indexer is always valid to be used on `a` and `i`.
reverseFunc += `a[${perm[i]}]=i[${i}];`;
}
return (reverseFunc += 'return a;}');
};
@ -71,7 +73,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
const outputShape = getOutputShape(inputTensor.dims, perm);
let newInputShape = inputTensor.dims;
let newOutputShape = outputShape;
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims);
let getShaderSource;
if (transposeAsReshape) {
getShaderSource = (shaderHelper: ShaderHelper) => {