[JS/WebGPU] Fix Split and Where to handle corner cases. (#19613)

### Description
<!-- Describe your changes. -->
1. Fix Where operator to handle Boolean input less than 4 bytes.
2. Fix JSEP test harness to use tensor names consistently.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2024-02-23 00:21:15 -08:00 committed by GitHub
parent 5e432a3ae6
commit ae3d73c981
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 38 additions and 3 deletions

View file

@ -27,7 +27,7 @@ const createWhereOpProgramShader =
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
@ -38,6 +38,7 @@ const createWhereOpProgramShader =
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};

View file

@ -168,5 +168,39 @@
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
"inputs": [
{
"data": [true, false],
"dims": [1, 1, 2, 1],
"type": "bool"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 4],
"type": "float32"
},
{
"data": [5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
}
]

View file

@ -627,8 +627,8 @@ export async function runModelTestSet(
try {
const feeds: Record<string, ort.Tensor> = {};
const outputsMetaInfo: Record<string, ort.Tensor> = {};
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
const [start, end, outputs] =
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {