accumulate in fp32 for Reduce* (#19868)

This commit is contained in:
Guenther Schmuelling 2024-03-18 08:28:43 -07:00 committed by GitHub
parent 28ad6c3955
commit 7e0d424934
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo =
const workgroupSize = 32;
const sharedMemorySnippet = `
var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo =
let outputIndex = global_idx / ${workgroupSize};
let offset = outputIndex * uniforms.reduceSize;
var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
var bestValue = f32(${reduceInitValues[reduceType]});
let Length = uniforms.reduceSize;
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
let candidate = f32(${input.getByOffset('offset + k')});
bestValue = ${reduceOps[reduceType]};
}
aBestValues[local_idx] = bestValue;
@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo =
output.setByOffset(
'outputIndex',
`${
reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
`${reduceOutputValues[reduceType]}`}`)};
reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` :
`${output.type.storage}(${reduceOutputValues[reduceType]})`}`)};
}
}`;