[JS/WebGPU] MultiheadAttention bugfix (#20447)

### Description
Fixed pastkey, key and pastvalue, value concatenation condition and
fixed index error. Added new test cases.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Satya Kumar Jandhyala 2024-04-24 08:43:14 -07:00 committed by GitHub
parent 33d5ea39b3
commit ae78cdb5d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 679 additions and 25 deletions

View file

@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
const createAttentionProbsProgramInfo =
(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
parameters: AttentionParameters, attributes: AttentionAttrs) => {
const probsShape =
[parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength];
parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
// TODO: handle mask
@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo =
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
const dispatch = {
x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
x: Math.ceil(totalSequenceLength / TILE_SIZE),
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
z: parameters.batchSize * parameters.numHeads
};
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
{type: DataType.uint32, data: parameters.totalSequenceLength},
{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength},
{type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
{type: q.dataType, data: alpha}
];
@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo =
const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'},
{name: 'alpha', type: dataType as UniformDataElementType}
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
];
return `
const beta: ${dataType} = 1.0;
@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo =
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
var value = ${qInput.type.value}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo =
const createVxAttentionScoreProgramInfo =
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
const TILE_SIZE = 12;
const dispatch = {
@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo =
z: params.batchSize * params.numHeads
};
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: params.vHiddenSize}
];
@ -537,24 +537,25 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
const outputPresentKey = context.outputCount > 1;
const outputPresentValue = context.outputCount > 2;
const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
// Concatinate pastKey and K to produce presentKey.
const presentKeyShape =
[parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
const key = (context.outputCount > 1 || pastKey) ?
context.compute(
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
{inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] :
k;
const key = outputPresentKey ? context.compute(
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
{inputs: concatKeyInputs, outputs: [1]})[0] :
k;
// Concatinate pastValue and V to produce presentValue.
const presentValueShape =
[parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatValueInputs = pastValue ? [pastValue, v] : [v];
const value = (context.outputCount > 2 || pastValue) ?
const value = outputPresentValue ?
context.compute(
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
{inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] :
{inputs: concatValueInputs, outputs: [2]})[0] :
v;
const inputsK = [q, key];
if (relativePositionBias) {
@ -563,20 +564,22 @@ export const applyAttention =
// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes),
createAttentionProbsProgramInfo(
context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
{inputs: inputsK, outputs: [-1]})[0];
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
parameters.totalSequenceLength),
totalSequenceLength),
{inputs: [probs], outputs: []});
// Run AttrionScore
const inputsV = [probs, value];
context.compute(
createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]});
createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
{inputs: inputsV, outputs: [0]});
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {

View file

@ -190,5 +190,656 @@
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [4],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [5],
"dims": [1, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [13, 14, 15, 16],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// PastValue
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [4],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [5],
"dims": [1, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [4.761593818664551],
"dims": [1, 1, 1],
"type": "float32"
},
{
"data": [4, 2],
"dims": [1, 1, 2, 1],
"type": "float32"
},
{
"data": [5, 3],
"dims": [1, 1, 2, 1],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// Past Key
{
"data": [13, 14, 15, 16],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// Past Value
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 4],
"type": "float32"
},
// Present key
{
"data": [13, 14, 15, 16, 5, 6, 7, 8],
"dims": [1, 1, 2, 4],
"type": "float32"
},
// Present value
{
"data": [17, 18, 19, 20, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [4],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [5],
"dims": [1, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [13, 14, 15, 16],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// PastValue
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [13, 14, 15, 16],
"dims": [1, 4, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [17, 18, 19, 20],
"dims": [1, 4, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [16.997316360473633, 18, 19, 20],
"dims": [1, 1, 4],
"type": "float32"
},
{
"data": [13, 5, 14, 6, 15, 7, 16, 8],
"dims": [1, 4, 2, 1],
"type": "float32"
},
{
"data": [17, 9, 18, 10, 19, 11, 20, 12],
"dims": [1, 4, 2, 1],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, 4 heads and head-size=4 with pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
"dims": [1, 1, 16],
"type": "float32"
},
// K
{
"data": [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
"dims": [1, 1, 16],
"type": "float32"
},
// V
{
"data": [2, 4, 8, 16, 1, 3, 9, 27, 1, 2, 4, 8, 16, 32, 64, 128],
"dims": [1, 1, 16],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": null,
"type": "float32"
},
// Past Key
{
"data": [13, 14, 15, 16, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12],
"dims": [1, 4, 1, 4],
"type": "float32"
},
// Past Value
{
"data": [17, 18, 19, 20, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [1, 4, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [
16.899608612060547, 17.906301498413086, 18.926380157470703, 19.973230361938477, 1, 3, 9, 27, 1, 2, 4, 8,
5, 6, 7, 8
],
"dims": [1, 1, 16],
"type": "float32"
},
// Present key
{
"data": [
13, 14, 15, 16, 16, 15, 14, 13, 5, 6, 7, 8, 12, 11, 10, 9, 1, 2, 3, 4, 8, 7, 6, 5, 9, 10, 11, 12, 4, 3, 2,
1
],
"dims": [1, 4, 2, 4],
"type": "float32"
},
// Present value
{
"data": [
17, 18, 19, 20, 2, 4, 8, 16, 9, 10, 11, 12, 1, 3, 9, 27, 1, 2, 3, 4, 1, 2, 4, 8, 5, 6, 7, 8, 16, 32, 64,
128
],
"dims": [1, 4, 2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// Past Key
{
"data": [13, 14, 15, 16],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// Past Value
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 4],
"type": "float32"
},
// Present key
{
"data": [13, 14, 15, 16, 5, 6, 7, 8],
"dims": [1, 1, 2, 4],
"type": "float32"
},
// Present value
{
"data": [17, 18, 19, 20, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
}
]