mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] fix a bug in transpose shader (#22997)
### Description Fix a bug in transpose shader, when input/output rank is 1. ### Motivation and Context Fixes #22994
This commit is contained in:
parent
e84b8e7bd5
commit
06526af346
1 changed files with 4 additions and 2 deletions
|
|
@ -29,7 +29,9 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
|
|||
let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} {
|
||||
var a: ${input.type.indices};`;
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`);
|
||||
// input indices and output indices should always be larger or equal to 2,
|
||||
// so indexer is always valid to be used on `a` and `i`.
|
||||
reverseFunc += `a[${perm[i]}]=i[${i}];`;
|
||||
}
|
||||
return (reverseFunc += 'return a;}');
|
||||
};
|
||||
|
|
@ -71,7 +73,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
const outputShape = getOutputShape(inputTensor.dims, perm);
|
||||
let newInputShape = inputTensor.dims;
|
||||
let newOutputShape = outputShape;
|
||||
const transposeAsReshape = isTransposeReshape(perm, inputTensor.dims);
|
||||
const transposeAsReshape = inputRank < 2 || isTransposeReshape(perm, inputTensor.dims);
|
||||
let getShaderSource;
|
||||
if (transposeAsReshape) {
|
||||
getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
|
|||
Loading…
Reference in a new issue