mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
[js/webgpu] fix two errors of attention operator (#21687)
Fix two issues: (1) the scale shall be fp32 instead of fp16; (2) the Softmax program does not handle normalized dispatch group values, so if the sequence length is over 65535, the result of this program is incorrect.
This commit is contained in:
parent
6db3d63add
commit
9c6ee89fa7
1 changed file with 5 additions and 7 deletions
|
|
@ -243,7 +243,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
}
|
||||
const elementsPerThread = Math.ceil(d / components / WG);
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
|
||||
{type: DataType.float, data: 1 / d}, {type: DataType.uint32, data: dComp},
|
||||
{type: DataType.uint32, data: elementsPerThread}
|
||||
];
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
|
||||
|
|
@ -252,10 +252,8 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
|
||||
const elemValueType = tensorTypeToWsglValueType(input.dataType);
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'},
|
||||
{name: 'elements_per_thread', type: 'u32'}
|
||||
];
|
||||
const uniforms: UniformsArrayType =
|
||||
[{name: 'd_inv', type: 'f32'}, {name: 'd_comp', type: 'u32'}, {name: 'elements_per_thread', type: 'u32'}];
|
||||
|
||||
return `
|
||||
var<workgroup> thread_max: array<f32, ${WG}>;
|
||||
|
|
@ -265,7 +263,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
WG, 1, 1
|
||||
])}
|
||||
let local_offset = local_idx * uniforms.elements_per_thread;
|
||||
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
|
||||
let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
|
||||
|
||||
var thread_max_vector = ${f32Type}(-3.402823e+38f);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
|
|
@ -315,7 +313,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
|
||||
if (sum == 0) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
x[offset + i] = ${inputHelper.type.value}(uniforms.d_inv);
|
||||
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
|
||||
}
|
||||
} else {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue