mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
[js/webgpu] only declare shape and strides in shader when necessary (#18940)
### Description Previously, shape and strides were added unconditionally even they are not used. This PR fixes this issue and only adds shape and strides when they are required. It's useful when some shapes are not used as uniform if the program depends on type instead of rank.
This commit is contained in:
parent
c613cc58a9
commit
44584c3ebe
1 changed files with 13 additions and 4 deletions
|
|
@ -588,30 +588,39 @@ const createIndicesHelper =
|
|||
|
||||
const impl = () => {
|
||||
const impls = [];
|
||||
if (!useUniform) {
|
||||
impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`);
|
||||
impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`);
|
||||
}
|
||||
let needShapeStrides = false;
|
||||
if (implementationUsed.offsetToIndices) {
|
||||
impls.push(offsetToIndicesImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.indicesToOffset) {
|
||||
impls.push(indicesToOffsetImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.broadcastedIndicesToOffset) {
|
||||
Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl));
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.set) {
|
||||
impls.push(setImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.setByIndices) {
|
||||
impls.push(setByIndicesImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.get) {
|
||||
impls.push(getImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (implementationUsed.getByIndices) {
|
||||
impls.push(getByIndicesImplementation);
|
||||
needShapeStrides = true;
|
||||
}
|
||||
if (!useUniform && needShapeStrides) {
|
||||
impls.unshift(
|
||||
`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`,
|
||||
`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`);
|
||||
}
|
||||
return impls.join('\n');
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue