mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[JS/WebGPU] Fix Split and Where to handle corner cases. (#19613)
### Description <!-- Describe your changes. --> 1. Fix Where operator to handle Boolean input less than 4 bytes. 2. Fix JSEP test harness to use tensor names consistently. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
5e432a3ae6
commit
ae3d73c981
3 changed files with 38 additions and 3 deletions
|
|
@ -27,7 +27,7 @@ const createWhereOpProgramShader =
|
|||
const expressionA = `a_data[index_a${x}][component_a${x}]`;
|
||||
const expressionB = `b_data[index_b${x}][component_b${x}]`;
|
||||
// eslint-disable-next-line no-bitwise
|
||||
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
|
||||
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
|
||||
return `
|
||||
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
|
||||
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
|
||||
|
|
@ -38,6 +38,7 @@ const createWhereOpProgramShader =
|
|||
let index_c${x} = offset_c${x} / 4u;
|
||||
let component_a${x} = offset_a${x} % 4u;
|
||||
let component_b${x} = offset_b${x} % 4u;
|
||||
let component_c${x} = offset_c${x} % 4u;
|
||||
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
|
||||
`;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -168,5 +168,39 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Where with no attributes",
|
||||
"operator": "Where",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [true, false],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "bool"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 9, 10, 11, 12],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -627,8 +627,8 @@ export async function runModelTestSet(
|
|||
try {
|
||||
const feeds: Record<string, ort.Tensor> = {};
|
||||
const outputsMetaInfo: Record<string, ort.Tensor> = {};
|
||||
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
|
||||
testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
|
||||
testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
|
||||
testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
|
||||
const [start, end, outputs] =
|
||||
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
|
||||
if (context.perfData.count === 0) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue