diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 435267a1b9..30a406cd21 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -243,7 +243,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor } const elementsPerThread = Math.ceil(d / components / WG); const programUniforms: ProgramUniform[] = [ - {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.float, data: 1 / d}, {type: DataType.uint32, data: dComp}, {type: DataType.uint32, data: elementsPerThread} ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); @@ -252,10 +252,8 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor const getShaderSource = (shaderHelper: ShaderHelper) => { const inputHelper = outputVariable('x', input.dataType, input.dims, components); const elemValueType = tensorTypeToWsglValueType(input.dataType); - const uniforms: UniformsArrayType = [ - {name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'}, - {name: 'elements_per_thread', type: 'u32'} - ]; + const uniforms: UniformsArrayType = + [{name: 'd_inv', type: 'f32'}, {name: 'd_comp', type: 'u32'}, {name: 'elements_per_thread', type: 'u32'}]; return ` var thread_max: array; @@ -265,7 +263,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor WG, 1, 1 ])} let local_offset = local_idx * uniforms.elements_per_thread; - let offset = workgroup_id.x * uniforms.d_comp + local_offset; + let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset; var thread_max_vector = ${f32Type}(-3.402823e+38f); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { @@ -315,7 +313,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor if (sum == 0) { for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { - x[offset + i] = ${inputHelper.type.value}(uniforms.d_inv); + x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv)); } } else { for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {