mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Eliminate explicit Concat operations in Attention (#20556)
### Description Remove the explicit concatenation of pastKey with Key and of pastValue with Value. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
535a030b1e
commit
bab5037eab
3 changed files with 149 additions and 75 deletions
|
|
@ -6,7 +6,6 @@ import {TensorView} from '../../tensor-view';
|
|||
import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
|
||||
import {createConcatProgramInfo} from './concat';
|
||||
|
||||
export const enum AttentionQkvFormat {
|
||||
unknown, // enum value not set, or depends on qkv projection implementation details
|
||||
|
|
@ -336,10 +335,15 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
};
|
||||
|
||||
const createAttentionProbsProgramInfo =
|
||||
(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
|
||||
parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
|
||||
(context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined,
|
||||
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs,
|
||||
pastSequenceLength: number) => {
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
|
||||
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
|
||||
const presentKeyShape = presentKey ?
|
||||
[parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] :
|
||||
undefined;
|
||||
|
||||
// TODO: handle mask
|
||||
|
||||
|
|
@ -355,34 +359,51 @@ const createAttentionProbsProgramInfo =
|
|||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
|
||||
{type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
|
||||
{type: DataType.float, data: alpha}
|
||||
{type: DataType.float, data: alpha}, {type: DataType.uint32, data: pastSequenceLength},
|
||||
{type: DataType.uint32, data: parameters.kvSequenceLength}
|
||||
];
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] =
|
||||
relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type'];
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
if (pastKey) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (relativePositionBias) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}];
|
||||
if (presentKey) {
|
||||
outputs.push({dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default});
|
||||
}
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const inputVars = [qInput, kInput];
|
||||
if (pastKey) {
|
||||
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
|
||||
inputVars.push(pastKeyInput);
|
||||
}
|
||||
if (relativePositionBias) {
|
||||
inputVars.push(
|
||||
inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims));
|
||||
}
|
||||
const output = outputVariable('output', q.dataType, probsShape);
|
||||
// const dataType = tensorTypeToWsglStorageType(q.dataType);
|
||||
const outputVars = [output];
|
||||
if (presentKey) {
|
||||
outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components));
|
||||
}
|
||||
const f32Type = tensorTypeToWsglValueType(DataType.float, components);
|
||||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType},
|
||||
{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
|
||||
${shaderHelper.mainStart([
|
||||
TILE_SIZE, TILE_SIZE, 1
|
||||
])}
|
||||
|
|
@ -391,15 +412,41 @@ const createAttentionProbsProgramInfo =
|
|||
let m = workgroup_id.y * TILE_SIZE;
|
||||
let n = workgroup_id.x * TILE_SIZE;
|
||||
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
|
||||
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
|
||||
|
||||
${(() => {
|
||||
if (pastKey && presentKey) {
|
||||
return `
|
||||
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
|
||||
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
|
||||
} else {
|
||||
return `
|
||||
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
|
||||
}
|
||||
})()}
|
||||
${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
|
||||
var value = ${f32Type}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
|
||||
}
|
||||
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];
|
||||
var idx = TILE_SIZE * local_id.y + local_id.x;
|
||||
${(() => {
|
||||
if (pastKey && presentKey) {
|
||||
return `
|
||||
if (n + local_id.y < uniforms.past_sequence_length) {
|
||||
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
} else {
|
||||
tileK[idx] =
|
||||
key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
|
||||
}`;
|
||||
} else {
|
||||
return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
|
||||
}
|
||||
})()}
|
||||
${
|
||||
presentKey ?
|
||||
'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' :
|
||||
''}
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
|
|
@ -432,23 +479,25 @@ const createAttentionProbsProgramInfo =
|
|||
};
|
||||
return {
|
||||
name: 'AttentionProbs',
|
||||
shaderCache: {hint: `${components}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms
|
||||
}),
|
||||
shaderCache: {
|
||||
hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
|
||||
inputDependencies
|
||||
},
|
||||
getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
const createVxAttentionScoreProgramInfo =
|
||||
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
|
||||
pastSequenceLength: number) => {
|
||||
(context: ComputeContext, probs: TensorView, v: TensorView, pastValue: TensorView|undefined,
|
||||
params: AttentionParameters, pastSequenceLength: number) => {
|
||||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const nReps = params.nReps ? params.nReps : 1;
|
||||
const repeatedVHiddenSize = params.vHiddenSize * nReps;
|
||||
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
|
||||
const presentValueShape =
|
||||
presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined;
|
||||
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
|
|
@ -460,23 +509,37 @@ const createVxAttentionScoreProgramInfo =
|
|||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
|
||||
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
|
||||
{type: DataType.uint32, data: repeatedVHiddenSize}
|
||||
{type: DataType.uint32, data: repeatedVHiddenSize}, {type: DataType.uint32, data: pastSequenceLength},
|
||||
{type: DataType.uint32, data: params.kvSequenceLength}
|
||||
];
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] =
|
||||
pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
|
||||
const outputs = [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}];
|
||||
if (presentValue) {
|
||||
outputs.push({dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default});
|
||||
}
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
|
||||
const vHelper = inputVariable('v', v.dataType, v.dims);
|
||||
const inputVars = [probsHelper, vHelper];
|
||||
if (pastValue) {
|
||||
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
|
||||
}
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
const outputVars = [output];
|
||||
if (presentValue) {
|
||||
outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!));
|
||||
}
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'},
|
||||
{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
|
||||
${shaderHelper.mainStart([
|
||||
TILE_SIZE, TILE_SIZE, 1
|
||||
])}
|
||||
|
|
@ -485,16 +548,43 @@ const createVxAttentionScoreProgramInfo =
|
|||
let n = global_id.x;
|
||||
|
||||
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
|
||||
let offsetB = headIdx * (uniforms.N * uniforms.K) + n;
|
||||
|
||||
${(() => {
|
||||
if (pastValue && presentValue) {
|
||||
return `
|
||||
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
|
||||
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
|
||||
`;
|
||||
} else {
|
||||
return `
|
||||
let offsetB = headIdx * uniforms.N * uniforms.K + n;
|
||||
`;
|
||||
}
|
||||
})()}
|
||||
${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
|
||||
var value = ${probsHelper.type.storage}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (m < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
|
||||
}
|
||||
if (n < uniforms.N && w + local_id.y < uniforms.K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N];
|
||||
}
|
||||
if (m < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
|
||||
}
|
||||
if (n < uniforms.N && w + local_id.y < uniforms.K) {
|
||||
var idx = TILE_SIZE * local_id.y + local_id.x;
|
||||
${(() => {
|
||||
if (pastValue && presentValue) {
|
||||
return `
|
||||
if (w + local_id.y < uniforms.past_sequence_length) {
|
||||
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
|
||||
} else {
|
||||
tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
|
||||
}
|
||||
`;
|
||||
} else {
|
||||
return `
|
||||
tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
|
||||
`;
|
||||
}
|
||||
})()}
|
||||
${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
|
||||
|
|
@ -515,12 +605,8 @@ const createVxAttentionScoreProgramInfo =
|
|||
|
||||
return {
|
||||
name: 'AttentionScore',
|
||||
shaderCache: {inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms
|
||||
}),
|
||||
shaderCache: {hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies},
|
||||
getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
|
@ -529,29 +615,12 @@ export const applyAttention =
|
|||
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
|
||||
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
|
||||
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const outputPresentKey = context.outputCount > 1;
|
||||
const outputPresentValue = context.outputCount > 2;
|
||||
const outputCount = context.outputCount;
|
||||
const pastSequenceLength =
|
||||
parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
|
||||
parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
// Concatinate pastKey and K to produce presentKey.
|
||||
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
|
||||
const key = parameters.kvNumHeads == null && outputPresentKey ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
|
||||
{inputs: concatKeyInputs, outputs: [1]})[0] :
|
||||
k;
|
||||
|
||||
// Concatinate pastValue and V to produce presentValue.
|
||||
const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatValueInputs = pastValue ? [pastValue, v] : [v];
|
||||
const value = parameters.kvNumHeads == null && outputPresentValue ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
|
||||
{inputs: concatValueInputs, outputs: [2]})[0] :
|
||||
v;
|
||||
const inputsK = [q, key];
|
||||
const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k];
|
||||
if (relativePositionBias) {
|
||||
inputsK.push(relativePositionBias);
|
||||
}
|
||||
|
|
@ -559,8 +628,9 @@ export const applyAttention =
|
|||
// Run AttentionProbs
|
||||
const probs = context.compute(
|
||||
createAttentionProbsProgramInfo(
|
||||
context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
|
||||
{inputs: inputsK, outputs: [-1]})[0];
|
||||
context, q, k, outputCount > 1 ? pastKey : undefined, relativePositionBias, parameters, attributes,
|
||||
pastSequenceLength),
|
||||
{inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0];
|
||||
|
||||
// Run Softmax
|
||||
context.compute(
|
||||
|
|
@ -570,10 +640,12 @@ export const applyAttention =
|
|||
{inputs: [probs], outputs: []});
|
||||
|
||||
// Run AttentionScore
|
||||
const inputsV = [probs, value];
|
||||
const inputsV =
|
||||
(parameters.kvNumHeads === undefined && outputCount > 1 && pastValue) ? [probs, v, pastValue] : [probs, v];
|
||||
context.compute(
|
||||
createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
|
||||
{inputs: inputsV, outputs: [0]});
|
||||
createVxAttentionScoreProgramInfo(
|
||||
context, probs, v, outputCount > 1 && pastValue ? pastValue : undefined, parameters, pastSequenceLength),
|
||||
{inputs: inputsV, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [0, 2] : [0]});
|
||||
};
|
||||
|
||||
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
|
|||
return codeLines.join('\n');
|
||||
};
|
||||
|
||||
export const createConcatProgramInfo =
|
||||
const createConcatProgramInfo =
|
||||
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -397,7 +397,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -474,7 +474,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -512,7 +512,8 @@
|
|||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"data": [10, 20],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
|
|
@ -539,7 +540,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -577,7 +578,8 @@
|
|||
},
|
||||
// RelativePositionBias
|
||||
{
|
||||
"data": null,
|
||||
"data": [100, 200],
|
||||
"dims": [1, 1, 1, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
|
|
@ -841,7 +843,7 @@
|
|||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue",
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
|
|
@ -976,7 +978,7 @@
|
|||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3.0006706714630127],
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
|
|
@ -1052,7 +1054,7 @@
|
|||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234],
|
||||
"data": [9, 10, 11, 12],
|
||||
"dims": [1, 1, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue