mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
[js/webgpu] fixes for fp16 attention (#20440)
This commit is contained in:
parent
80213a9e66
commit
33d5ea39b3
2 changed files with 4 additions and 4 deletions
|
|
@@ -264,7 +264,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
let local_offset = local_idx * uniforms.elements_per_thread;
|
||||
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
|
||||
|
||||
var thread_max_vector = ${inputHelper.type.value}(-3.402823e+38f);
|
||||
var thread_max_vector = ${f32Type}(-3.402823e+38f);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
|
||||
}
|
||||
|
|
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
})()};
|
||||
workgroupBarrier();
|
||||
|
||||
var max_value: f32 = -3.402823e+38f;
|
||||
var max_value = -3.402823e+38f;
|
||||
for (var i = 0u; i < ${WG}; i++) {
|
||||
max_value = max(thread_max[i], max_value);
|
||||
}
|
||||
|
||||
var sum_vector = ${inputHelper.type.value}(${0});
|
||||
var sum_vector = ${f32Type}(${0});
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
|
||||
}
|
||||
|
|
|
|||
|
|
@@ -313,7 +313,7 @@ export const castToF32 = (dataType: string, components: number, value: string) =
|
|||
return `f32(${value})`;
|
||||
}
|
||||
|
||||
return `vec${components}f32(${value})`;
|
||||
return `vec${components}<f32>(${value})`;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Reference in a new issue