[js/webgpu] fixes for fp16 attention (#20440)

This commit is contained in:
Guenther Schmuelling 2024-04-24 08:01:28 -07:00 committed by GitHub
parent 80213a9e66
commit 33d5ea39b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View file

@@ -264,7 +264,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
let local_offset = local_idx * uniforms.elements_per_thread;
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
-var thread_max_vector = ${inputHelper.type.value}(-3.402823e+38f);
+var thread_max_vector = ${f32Type}(-3.402823e+38f);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
}
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
})()};
workgroupBarrier();
-var max_value: f32 = -3.402823e+38f;
+var max_value = -3.402823e+38f;
for (var i = 0u; i < ${WG}; i++) {
max_value = max(thread_max[i], max_value);
}
-var sum_vector = ${inputHelper.type.value}(${0});
+var sum_vector = ${f32Type}(${0});
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
}

View file

@@ -313,7 +313,7 @@ export const castToF32 = (dataType: string, components: number, value: string) =
return `f32(${value})`;
}
-return `vec${components}f32(${value})`;
+return `vec${components}<f32>(${value})`;
};
/**