From 33d5ea39b378d82ce0a6038e61844a6819281d9c Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 24 Apr 2024 08:01:28 -0700 Subject: [PATCH] [js/webgpu] fixes for fp16 attention (#20440) --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 6 +++--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index e8dc702d6b..57e96640c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -264,7 +264,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor let local_offset = local_idx * uniforms.elements_per_thread; let offset = workgroup_id.x * uniforms.d_comp + local_offset; - var thread_max_vector = ${inputHelper.type.value}(-3.402823e+38f); + var thread_max_vector = ${f32Type}(-3.402823e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector); } @@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor })()}; workgroupBarrier(); - var max_value: f32 = -3.402823e+38f; + var max_value = -3.402823e+38f; for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } - var sum_vector = ${inputHelper.type.value}(${0}); + var sum_vector = ${f32Type}(${0}); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { sum_vector += exp(${f32Type}(x[offset + i]) - max_value); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 5e27e79087..ec2831a3cc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -313,7 +313,7 @@ export const castToF32 = (dataType: string, components: number, value: string) = return `f32(${value})`; } - return `vec${components}f32(${value})`; + return `vec${components}(${value})`; }; /**