[js/webgpu] validate transpose perm if specified (#23197)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
xhcao 2025-01-02 07:58:54 +08:00 committed by GitHub
parent 0b87bccca8
commit a3833a5e79
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,14 +13,18 @@ export interface TransposeAttributes extends AttributeWithCacheKey {
readonly perm: number[];
}
const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], perm: readonly number[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Transpose requires 1 input.');
}
if (perm.length !== 0 && perm.length !== inputs[0].dims.length) {
throw new Error(`perm size ${perm.length} does not match input rank ${inputs[0].dims.length}`);
}
};
const getAdjustedPerm = (inputRank: number, perm: number[]): number[] =>
perm && perm.length !== inputRank ? [...new Array(inputRank).keys()].reverse() : perm;
perm.length !== 0 ? perm : [...new Array(inputRank).keys()].reverse();
const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] =>
ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm));
@ -191,7 +195,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
};
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
validateInputs(context.inputs, attributes.perm);
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
};