mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[WIP][JS/WebGPU] Inputs Key and Value could be 4-dims. (#20470)
### Description The Key and Value inputs could be 4-dims ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
2c19db0af1
commit
21b3cbc3af
3 changed files with 239 additions and 11 deletions
|
|
@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
})()};
|
||||
workgroupBarrier();
|
||||
|
||||
var max_value = -3.402823e+38f;
|
||||
var max_value = f32(-3.402823e+38f);
|
||||
for (var i = 0u; i < ${WG}; i++) {
|
||||
max_value = max(thread_max[i], max_value);
|
||||
}
|
||||
|
||||
var sum_vector = ${f32Type}(${0});
|
||||
var sum_vector = ${f32Type}(0);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
|
||||
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
|
||||
}
|
||||
|
|
@ -378,7 +378,6 @@ const createAttentionProbsProgramInfo =
|
|||
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
|
||||
];
|
||||
return `
|
||||
const beta: ${dataType} = 1.0;
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
|
@ -426,16 +425,16 @@ const createAttentionProbsProgramInfo =
|
|||
throw new Error(`Unsupported components: ${components}`);
|
||||
}
|
||||
})()};
|
||||
output[outputIdx] = sum * uniforms.alpha;
|
||||
|
||||
${(() => {
|
||||
if (relativePositionBiasInput) {
|
||||
return `
|
||||
let batch = workgroup_id.z / uniforms.num_heads;
|
||||
let head = workgroup_id.z % uniforms.num_heads;
|
||||
var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x);
|
||||
output[outputIdx] += ${relativePositionBiasInput.getByIndices('indices')};`;
|
||||
output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`;
|
||||
}
|
||||
return '';
|
||||
return 'output[outputIdx] = sum * uniforms.alpha;';
|
||||
})()}
|
||||
}
|
||||
}`;
|
||||
|
|
@ -512,7 +511,6 @@ const createVxAttentionScoreProgramInfo =
|
|||
// we need to transpose output from BNSH_v to BSND_v
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
|
||||
let headOffset = (batchIdx * uniforms.M * uniforms.num_heads + currentBatchHeadNumber) * uniforms.N;
|
||||
if (m < uniforms.M && n < uniforms.N) {
|
||||
let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size
|
||||
+ currentBatchHeadNumber * uniforms.N + n;
|
||||
|
|
|
|||
|
|
@ -339,7 +339,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
|
|||
|
||||
if (kvBNSH) {
|
||||
return applyAttention(
|
||||
context, Q, key, value, keyPaddingMask, undefined, undefined, undefined, relativePositionBias, params,
|
||||
context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params,
|
||||
attributes);
|
||||
}
|
||||
if (!key || !value) {
|
||||
|
|
|
|||
|
|
@ -604,7 +604,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
|
||||
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey, pastValue, presentKey and presentValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
|
||||
|
|
@ -765,7 +765,83 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2.0],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3.0],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": [10, 20],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [4.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [5.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3.0006706714630127],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 2],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5, 3],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -803,7 +879,7 @@
|
|||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": [10, 20],
|
||||
"data": [100, 200],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
|
|
@ -821,8 +897,162 @@
|
|||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present key
|
||||
{
|
||||
"data": [13, 14, 15, 16, 5, 6, 7, 8],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present value
|
||||
{
|
||||
"data": [17, 18, 19, 20, 9, 10, 11, 12],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size one with pastKey and pastValue; kvBNSH (4-dim Key and Value, 3-dim Q)",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": [10, 20],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [4.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [5.0],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3.0006706714630127],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 2],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5, 3],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size 4 with pastKey and pastValue; Key and Value 4-dims",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": [50, 100],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue