mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
[JS/WebGPU] MultiheadAttention bugfix (#20447)
### Description Fixed the conditions under which pastKey is concatenated with key and pastValue is concatenated with value, and fixed an indexing error. Added new test cases. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
33d5ea39b3
commit
ae78cdb5d7
2 changed files with 679 additions and 25 deletions
|
|
@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
|
||||
const createAttentionProbsProgramInfo =
|
||||
(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
|
||||
parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const probsShape =
|
||||
[parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength];
|
||||
parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
|
||||
|
||||
// TODO: handle mask
|
||||
|
||||
|
|
@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo =
|
|||
const vectorizedHeadSize = parameters.headSize / components;
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
|
||||
x: Math.ceil(totalSequenceLength / TILE_SIZE),
|
||||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
|
||||
{type: DataType.uint32, data: parameters.totalSequenceLength},
|
||||
{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength},
|
||||
{type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
|
||||
{type: q.dataType, data: alpha}
|
||||
];
|
||||
|
||||
|
|
@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo =
|
|||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'},
|
||||
{name: 'alpha', type: dataType as UniformDataElementType}
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
|
||||
];
|
||||
return `
|
||||
const beta: ${dataType} = 1.0;
|
||||
|
|
@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo =
|
|||
let m = workgroup_id.y * TILE_SIZE;
|
||||
let n = workgroup_id.x * TILE_SIZE;
|
||||
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
|
||||
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
|
||||
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
|
||||
|
||||
var value = ${qInput.type.value}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
|
|
@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo =
|
|||
|
||||
|
||||
const createVxAttentionScoreProgramInfo =
|
||||
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
|
||||
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
|
||||
pastSequenceLength: number) => {
|
||||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
|
|
@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo =
|
|||
z: params.batchSize * params.numHeads
|
||||
};
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
|
||||
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
|
||||
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
|
||||
{type: DataType.uint32, data: params.vHiddenSize}
|
||||
];
|
||||
|
|
@ -537,24 +537,25 @@ export const applyAttention =
|
|||
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
|
||||
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
|
||||
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const outputPresentKey = context.outputCount > 1;
|
||||
const outputPresentValue = context.outputCount > 2;
|
||||
const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
// Concatenate pastKey and K to produce presentKey.
|
||||
const presentKeyShape =
|
||||
[parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
|
||||
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
|
||||
const key = (context.outputCount > 1 || pastKey) ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
|
||||
{inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] :
|
||||
k;
|
||||
const key = outputPresentKey ? context.compute(
|
||||
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
|
||||
{inputs: concatKeyInputs, outputs: [1]})[0] :
|
||||
k;
|
||||
|
||||
// Concatenate pastValue and V to produce presentValue.
|
||||
const presentValueShape =
|
||||
[parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
|
||||
const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatValueInputs = pastValue ? [pastValue, v] : [v];
|
||||
const value = (context.outputCount > 2 || pastValue) ?
|
||||
const value = outputPresentValue ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
|
||||
{inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] :
|
||||
{inputs: concatValueInputs, outputs: [2]})[0] :
|
||||
v;
|
||||
const inputsK = [q, key];
|
||||
if (relativePositionBias) {
|
||||
|
|
@ -563,20 +564,22 @@ export const applyAttention =
|
|||
|
||||
// Run AttentionProbs
|
||||
const probs = context.compute(
|
||||
createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes),
|
||||
createAttentionProbsProgramInfo(
|
||||
context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
|
||||
{inputs: inputsK, outputs: [-1]})[0];
|
||||
|
||||
// Run Softmax
|
||||
context.compute(
|
||||
createInPlaceSoftmaxProgramInfo(
|
||||
context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
|
||||
parameters.totalSequenceLength),
|
||||
totalSequenceLength),
|
||||
{inputs: [probs], outputs: []});
|
||||
|
||||
// Run AttentionScore
|
||||
const inputsV = [probs, value];
|
||||
context.compute(
|
||||
createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]});
|
||||
createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
|
||||
{inputs: inputsV, outputs: [0]});
|
||||
};
|
||||
|
||||
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
||||
|
|
|
|||
|
|
@ -190,5 +190,656 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [4],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [5],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [4],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [5],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [4.761593818664551],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 2],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5, 3],
|
||||
"dims": [1, 1, 2, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Key
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Value
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present key
|
||||
{
|
||||
"data": [13, 14, 15, 16, 5, 6, 7, 8],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present value
|
||||
{
|
||||
"data": [17, 18, 19, 20, 9, 10, 11, 12],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [4],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [5],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 4, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 4, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [16.997316360473633, 18, 19, 20],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [13, 5, 14, 6, 15, 7, 16, 8],
|
||||
"dims": [1, 4, 2, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [17, 9, 18, 10, 19, 11, 20, 12],
|
||||
"dims": [1, 4, 2, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, 4 heads and head-size=4 with pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
"dims": [1, 1, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
|
||||
"dims": [1, 1, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [2, 4, 8, 16, 1, 3, 9, 27, 1, 2, 4, 8, 16, 32, 64, 128],
|
||||
"dims": [1, 1, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Key
|
||||
{
|
||||
"data": [13, 14, 15, 16, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12],
|
||||
"dims": [1, 4, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Value
|
||||
{
|
||||
"data": [17, 18, 19, 20, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 4, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
16.899608612060547, 17.906301498413086, 18.926380157470703, 19.973230361938477, 1, 3, 9, 27, 1, 2, 4, 8,
|
||||
5, 6, 7, 8
|
||||
],
|
||||
"dims": [1, 1, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present key
|
||||
{
|
||||
"data": [
|
||||
13, 14, 15, 16, 16, 15, 14, 13, 5, 6, 7, 8, 12, 11, 10, 9, 1, 2, 3, 4, 8, 7, 6, 5, 9, 10, 11, 12, 4, 3, 2,
|
||||
1
|
||||
],
|
||||
"dims": [1, 4, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present value
|
||||
{
|
||||
"data": [
|
||||
17, 18, 19, 20, 2, 4, 8, 16, 9, 10, 11, 12, 1, 3, 9, 27, 1, 2, 3, 4, 1, 2, 4, 8, 5, 6, 7, 8, 16, 32, 64,
|
||||
128
|
||||
],
|
||||
"dims": [1, 4, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [5, 6, 7, 8],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": [10, 20],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Key
|
||||
{
|
||||
"data": [13, 14, 15, 16],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Past Value
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [17, 18, 19, 20],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present key
|
||||
{
|
||||
"data": [13, 14, 15, 16, 5, 6, 7, 8],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// Present value
|
||||
{
|
||||
"data": [17, 18, 19, 20, 9, 10, 11, 12],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue