mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[js/webgpu] Fix shader errors in indicesGet/Set when rank > 4 (#18661)
### Description
Currently, for non-uniform variables, we still use `array<u32, N>` type
instead of array<vec4<u32>, N1>`. So we can't always treat all variables
with rank > 4 as uniforms to index.
This PR fixes below errors:
```
error(s) generated while compiling the shader:
:5:44 error: index 4 out of bounds [0..1]
return uniforms.input_strides[4] * (outputIndices[4] % uniforms.input_shape[4])+uniforms.input_strides[3] * (outputIndices[3] % uniforms.input_shape[3])+uniforms.input_strides[2] * (outputIndices[2] % uniforms.input_shape[2])+uniforms.input_strides[1] * (outputIndices[1] % uniforms.input_shape[1])+uniforms.input_strides[0] * (outputIndices[0] % uniforms.input_shape[0]);
^
FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - float32
FAILED #OpTest# - expand.jsonc [webgpu]Expand - Expand 5D - float32 Expand 5 - shape < input.size()
This commit is contained in:
parent
eaaf27015e
commit
92ee664f64
2 changed files with 22 additions and 18 deletions
|
|
@ -326,16 +326,20 @@ export const sumVector = (name: string, components: number) => {
|
|||
};
|
||||
|
||||
/**
|
||||
* A helper function that returns uniform element at index.
|
||||
* @param name - the name of uniform element.
|
||||
* @param index - the index of uniform element.
|
||||
* @param length - the length of uniform element.
|
||||
* A helper function that returns variable element at index.
|
||||
* @param name - the name of variable.
|
||||
* @param index - the index of variable element.
|
||||
* @param length - the length of variable.
|
||||
*/
|
||||
export const getUniformElementAt = (name: string, index: number|string, length: number): string => {
|
||||
if (typeof (index) === 'string') {
|
||||
return length > 4 ? `${name}[(${index}) / 4][(${index}) % 4]` : length > 1 ? `${name}[${index}]` : name;
|
||||
export const getElementAt = (name: string, index: number|string, length: number): string => {
|
||||
if (name.startsWith('uniforms.') && length > 4) {
|
||||
if (typeof (index) === 'string') {
|
||||
return `${name}[(${index}) / 4][(${index}) % 4]`;
|
||||
} else {
|
||||
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
|
||||
}
|
||||
} else {
|
||||
return length > 4 ? `${name}[${Math.floor(index / 4)}][${index % 4}]` : length > 1 ? `${name}[${index}]` : name;
|
||||
return length > 1 ? `${name}[${index}]` : name;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -380,8 +384,8 @@ const createIndicesHelper =
|
|||
let o2iSnippet = '';
|
||||
for (let i = 0; i < rank - 1; i++) {
|
||||
o2iSnippet += `
|
||||
let dim${i} = current / ${getUniformElementAt(strides, i, rank)};
|
||||
let rest${i} = current % ${getUniformElementAt(strides, i, rank)};
|
||||
let dim${i} = current / ${getElementAt(strides, i, rank)};
|
||||
let rest${i} = current % ${getElementAt(strides, i, rank)};
|
||||
indices[${i}] = dim${i};
|
||||
current = rest${i};
|
||||
`;
|
||||
|
|
@ -404,7 +408,7 @@ const createIndicesHelper =
|
|||
const offsets: string[] = [];
|
||||
if (rank >= 2) {
|
||||
for (let i = rank - 1; i >= 0; i--) {
|
||||
offsets.push(`${getUniformElementAt(strides, i, rank)} * (indices[${i}])`);
|
||||
offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -425,7 +429,7 @@ const createIndicesHelper =
|
|||
if (rank < 2) {
|
||||
return `${varIndices}`;
|
||||
} else {
|
||||
return `${varIndices}[${idx}]`;
|
||||
return `${getElementAt(varIndices, idx, rank)}`;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -433,7 +437,7 @@ const createIndicesHelper =
|
|||
if (rank < 2) {
|
||||
return `${varIndices}=${value};`;
|
||||
} else {
|
||||
return `${varIndices}[${idx}]=${value};`;
|
||||
return `${getElementAt(varIndices, idx, rank)}=${value};`;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, getUniformElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
|
||||
export interface SliceAttributes extends AttributeWithCacheKey {
|
||||
readonly starts: number[];
|
||||
|
|
@ -82,10 +82,10 @@ const calculateInputIndicesImpl =
|
|||
var inputIndices: ${input.type.indices};
|
||||
var carry = 0u;
|
||||
for (var i = ${inputShape.length}; i >= 0; i--) {
|
||||
let input_shape_i = ${getUniformElementAt('uniforms.input_shape', 'i', inputShape.length)};
|
||||
let steps_i = ${getUniformElementAt('uniforms.steps', 'i', inputShape.length)};
|
||||
let signs_i = ${getUniformElementAt('uniforms.signs', 'i', inputShape.length)};
|
||||
let starts_i = ${getUniformElementAt('uniforms.starts', 'i', inputShape.length)};
|
||||
let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
|
||||
let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)};
|
||||
let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)};
|
||||
let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)};
|
||||
var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
|
||||
var inputIndex = outputIndex * steps_i + starts_i + carry;
|
||||
carry = inputIndex / input_shape_i;
|
||||
|
|
|
|||
Loading…
Reference in a new issue