From 9c6ee89fa7c89e4bf39b60f0ba636d1b9988735c Mon Sep 17 00:00:00 2001 From: xhcao Date: Wed, 14 Aug 2024 00:42:34 +0800 Subject: [PATCH] [js/webgpu] fix two errors of attention operator (#21687) Fix two issues: (1) scale shall be fp32 instead of f16 (2) Softmax program does not handle the normalized dispatch group values, so if the sequence length is over 65535, the result is not correct for this program. --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 435267a1b9..30a406cd21 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -243,7 +243,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor } const elementsPerThread = Math.ceil(d / components / WG); const programUniforms: ProgramUniform[] = [ - {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.float, data: 1 / d}, {type: DataType.uint32, data: dComp}, {type: DataType.uint32, data: elementsPerThread} ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); @@ -252,10 +252,8 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); const elemValueType = tensorTypeToWsglValueType(input.dataType); - const uniforms: UniformsArrayType = [ - {name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'}, - {name: 'elements_per_thread', type: 'u32'} - ]; + const uniforms: UniformsArrayType = + [{name: 'd_inv', type: 'f32'}, {name: 'd_comp', type: 'u32'}, {name: 'elements_per_thread', type: 'u32'}]; return ` var thread_max: array; @@ -265,7 +263,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor WG, 1, 1 ])} let local_offset = local_idx * uniforms.elements_per_thread; - let offset = workgroup_id.x * uniforms.d_comp + local_offset; + let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; var thread_max_vector = ${f32Type}(-3.402823e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { @@ -315,7 +313,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor if (sum == 0) { for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { - x[offset + i] = ${inputHelper.type.value}(uniforms.d_inv); + x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv)); } } else { for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {