From ae78cdb5d74dadcc4ab38bb7fe8ce79a04c36de8 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Wed, 24 Apr 2024 08:43:14 -0700 Subject: [PATCH] [JS/WebGPU] MultiheadAttention bugfix (#20447) ### Description Fixed the conditions under which pastKey is concatenated with key and pastValue is concatenated with value, and fixed an index error. Added new test cases. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 53 +- .../test/data/ops/multihead-attention.jsonc | 651 ++++++++++++++++++ 2 files changed, 679 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index 57e96640c3..db9bb73e39 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor const createAttentionProbsProgramInfo = (_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined, - parameters: AttentionParameters, attributes: AttentionAttrs) => { - const probsShape = - [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength]; + parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => { + const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; + const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; // TODO: handle mask @@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo = const vectorizedHeadSize = parameters.headSize / components; const TILE_SIZE = 12; const dispatch = { - x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + x: Math.ceil(totalSequenceLength / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: 
parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, - {type: DataType.uint32, data: parameters.totalSequenceLength}, - {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength}, + {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads}, {type: q.dataType, data: alpha} ]; @@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo = const uniforms: UniformsArrayType = [ {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}, - {name: 'alpha', type: dataType as UniformDataElementType} + {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} ]; return ` const beta: ${dataType} = 1.0; @@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo = let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K; + let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K; var value = ${qInput.type.value}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { @@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo = const createVxAttentionScoreProgramInfo = - (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { + (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters, + pastSequenceLength: number) => { + const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; const TILE_SIZE = 12; const dispatch = { @@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo = z: params.batchSize * params.numHeads }; const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: 
params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, {type: DataType.uint32, data: params.vHiddenSize} ]; @@ -537,24 +537,25 @@ export const applyAttention = (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined, relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const outputPresentKey = context.outputCount > 1; + const outputPresentValue = context.outputCount > 2; + const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0; + const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; // Concatinate pastKey and K to produce presentKey. - const presentKeyShape = - [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize]; + const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]; const concatKeyInputs = pastKey ? [pastKey, k] : [k]; - const key = (context.outputCount > 1 || pastKey) ? - context.compute( - createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType), - {inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] : - k; + const key = outputPresentKey ? context.compute( + createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType), + {inputs: concatKeyInputs, outputs: [1]})[0] : + k; // Concatinate pastValue and V to produce presentValue. 
- const presentValueShape = - [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize]; + const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]; const concatValueInputs = pastValue ? [pastValue, v] : [v]; - const value = (context.outputCount > 2 || pastValue) ? + const value = outputPresentValue ? context.compute( createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType), - {inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] : + {inputs: concatValueInputs, outputs: [2]})[0] : v; const inputsK = [q, key]; if (relativePositionBias) { @@ -563,20 +564,22 @@ export const applyAttention = // Run AttentionProbs const probs = context.compute( - createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes), + createAttentionProbsProgramInfo( + context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength), {inputs: inputsK, outputs: [-1]})[0]; // Run Softmax context.compute( createInPlaceSoftmaxProgramInfo( context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, - parameters.totalSequenceLength), + totalSequenceLength), {inputs: [probs], outputs: []}); // Run AttrionScore const inputsV = [probs, value]; context.compute( - createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]}); + createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength), + {inputs: inputsV, outputs: [0]}); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc index 05687bd482..0bed30747b 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -190,5 +190,656 @@ ] } ] + }, + { + "name": "MultiHeadAttention Basic, 
one head and head-size=1 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2], + "dims": [1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [4], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [5], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [13, 14, 15, 16], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // PastValue + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one 
head and head-size=1 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2], + "dims": [1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [4], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [5], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.761593818664551], + "dims": [1, 1, 1], + "type": "float32" + }, + { + "data": [4, 2], + "dims": [1, 1, 2, 1], + "type": "float32" + }, + { + "data": [5, 3], + "dims": [1, 1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // Past Key + { + "data": [13, 14, 15, 16], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // Past Value + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 1, 4], + "type": "float32" + } + 
], + "outputs": [ + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 4], + "type": "float32" + }, + // Present key + { + "data": [13, 14, 15, 16, 5, 6, 7, 8], + "dims": [1, 1, 2, 4], + "type": "float32" + }, + // Present value + { + "data": [17, 18, 19, 20, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2], + "dims": [1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [4], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [5], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [3], + "dims": [1, 1, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": 
null, + "type": "float32" + }, + // PastKey + { + "data": [13, 14, 15, 16], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // PastValue + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 4, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // PastKey + { + "data": [13, 14, 15, 16], + "dims": [1, 4, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [17, 18, 19, 20], + "dims": [1, 4, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [16.997316360473633, 18, 19, 20], + "dims": [1, 1, 4], + "type": "float32" + }, + { + "data": [13, 5, 14, 6, 15, 7, 16, 8], + "dims": [1, 4, 2, 1], + "type": "float32" + }, + { + "data": [17, 9, 18, 10, 19, 11, 20, 12], + "dims": [1, 4, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, 4 heads and head-size=4 with pastKey and pastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 4, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 1, 16], + "type": "float32" + }, 
+ // K + { + "data": [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "dims": [1, 1, 16], + "type": "float32" + }, + // V + { + "data": [2, 4, 8, 16, 1, 3, 9, 27, 1, 2, 4, 8, 16, 32, 64, 128], + "dims": [1, 1, 16], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": null, + "type": "float32" + }, + // Past Key + { + "data": [13, 14, 15, 16, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 4, 1, 4], + "type": "float32" + }, + // Past Value + { + "data": [17, 18, 19, 20, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 4, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 16.899608612060547, 17.906301498413086, 18.926380157470703, 19.973230361938477, 1, 3, 9, 27, 1, 2, 4, 8, + 5, 6, 7, 8 + ], + "dims": [1, 1, 16], + "type": "float32" + }, + // Present key + { + "data": [ + 13, 14, 15, 16, 16, 15, 14, 13, 5, 6, 7, 8, 12, 11, 10, 9, 1, 2, 3, 4, 8, 7, 6, 5, 9, 10, 11, 12, 4, 3, 2, + 1 + ], + "dims": [1, 4, 2, 4], + "type": "float32" + }, + // Present value + { + "data": [ + 17, 18, 19, 20, 2, 4, 8, 16, 9, 10, 11, 12, 1, 3, 9, 27, 1, 2, 3, 4, 1, 2, 4, 8, 5, 6, 7, 8, 16, 32, 64, + 128 + ], + "dims": [1, 4, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" 
+ }, + // RelativePositionBias + { + "data": [10, 20], + "dims": [1, 1, 1, 2], + "type": "float32" + }, + // Past Key + { + "data": [13, 14, 15, 16], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // Past Value + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [17, 18, 19, 20], + "dims": [1, 1, 4], + "type": "float32" + }, + // Present key + { + "data": [13, 14, 15, 16, 5, 6, 7, 8], + "dims": [1, 1, 2, 4], + "type": "float32" + }, + // Present value + { + "data": [17, 18, 19, 20, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ]