[js/webgpu] fix two errors of attention operator (#21687)

Fix two issues:
(1) The scale uniform must be f32 instead of f16.
(2) The Softmax program does not handle normalized dispatch group sizes, so when the sequence length exceeds 65535 its result is incorrect. The row offset is now derived from `global_idx` instead of `workgroup_id.x`.
This commit is contained in:
xhcao 2024-08-14 00:42:34 +08:00 committed by GitHub
parent 6db3d63add
commit 9c6ee89fa7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -243,7 +243,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
}
const elementsPerThread = Math.ceil(d / components / WG);
const programUniforms: ProgramUniform[] = [
{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
{type: DataType.float, data: 1 / d}, {type: DataType.uint32, data: dComp},
{type: DataType.uint32, data: elementsPerThread}
];
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
@ -252,10 +252,8 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
const getShaderSource = (shaderHelper: ShaderHelper) => {
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
const elemValueType = tensorTypeToWsglValueType(input.dataType);
const uniforms: UniformsArrayType = [
{name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'},
{name: 'elements_per_thread', type: 'u32'}
];
const uniforms: UniformsArrayType =
[{name: 'd_inv', type: 'f32'}, {name: 'd_comp', type: 'u32'}, {name: 'elements_per_thread', type: 'u32'}];
return `
var<workgroup> thread_max: array<f32, ${WG}>;
@ -265,7 +263,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
WG, 1, 1
])}
let local_offset = local_idx * uniforms.elements_per_thread;
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
let offset = (global_idx / ${WG}) * uniforms.d_comp + local_offset;
var thread_max_vector = ${f32Type}(-3.402823e+38f);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
@ -315,7 +313,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
x[offset + i] = ${inputHelper.type.value}(uniforms.d_inv);
x[offset + i] = ${inputHelper.type.value}(${elemValueType}(uniforms.d_inv));
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {