mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] validate transpose perm if specified (#23197)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
0b87bccca8
commit
a3833a5e79
1 changed files with 7 additions and 3 deletions
|
|
@ -13,14 +13,18 @@ export interface TransposeAttributes extends AttributeWithCacheKey {
|
|||
readonly perm: number[];
|
||||
}
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
const validateInputs = (inputs: readonly TensorView[], perm: readonly number[]): void => {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
throw new Error('Transpose requires 1 input.');
|
||||
}
|
||||
|
||||
if (perm.length !== 0 && perm.length !== inputs[0].dims.length) {
|
||||
throw new Error(`perm size ${perm.length} does not match input rank ${inputs[0].dims.length}`);
|
||||
}
|
||||
};
|
||||
|
||||
const getAdjustedPerm = (inputRank: number, perm: number[]): number[] =>
|
||||
perm && perm.length !== inputRank ? [...new Array(inputRank).keys()].reverse() : perm;
|
||||
perm.length !== 0 ? perm : [...new Array(inputRank).keys()].reverse();
|
||||
|
||||
const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] =>
|
||||
ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm));
|
||||
|
|
@ -191,7 +195,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
};
|
||||
|
||||
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
validateInputs(context.inputs, attributes.perm);
|
||||
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue