Eliminate explicit Concat operations in Attention (#20556)

### Description
Remove the explicit concatenation of pastKey with Key and of pastValue
with Value.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Satya Kumar Jandhyala 2024-05-24 09:07:57 -07:00 committed by GitHub
parent 535a030b1e
commit bab5037eab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 149 additions and 75 deletions

View file

@ -6,7 +6,6 @@ import {TensorView} from '../../tensor-view';
import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
import {createConcatProgramInfo} from './concat';
export const enum AttentionQkvFormat {
unknown, // enum value not set, or depends on qkv projection implementation details
@ -336,10 +335,15 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
};
const createAttentionProbsProgramInfo =
(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
(context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs,
pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
const presentKeyShape = presentKey ?
[parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] :
undefined;
// TODO: handle mask
@ -355,34 +359,51 @@ const createAttentionProbsProgramInfo =
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
{type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
{type: DataType.float, data: alpha}
{type: DataType.float, data: alpha}, {type: DataType.uint32, data: pastSequenceLength},
{type: DataType.uint32, data: parameters.kvSequenceLength}
];
const inputDependencies: ProgramInputTensorInfoDependency[] =
relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type'];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (pastKey) {
inputDependencies.push('type');
}
if (relativePositionBias) {
inputDependencies.push('type');
}
const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}];
if (presentKey) {
outputs.push({dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default});
}
const getShaderSource = (shaderHelper: ShaderHelper) => {
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
if (pastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
}
if (relativePositionBias) {
inputVars.push(
inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims));
}
const output = outputVariable('output', q.dataType, probsShape);
// const dataType = tensorTypeToWsglStorageType(q.dataType);
const outputVars = [output];
if (presentKey) {
outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components));
}
const f32Type = tensorTypeToWsglValueType(DataType.float, components);
const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType},
{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
@ -391,15 +412,41 @@ const createAttentionProbsProgramInfo =
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
${(() => {
if (pastKey && presentKey) {
return `
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
} else {
return `
let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`;
}
})()}
${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''}
var value = ${f32Type}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
}
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastKey && presentKey) {
return `
if (n + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
} else {
tileK[idx] =
key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];
}`;
} else {
return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];';
}
})()}
${
presentKey ?
'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' :
''}
}
workgroupBarrier();
@ -432,23 +479,25 @@ const createAttentionProbsProgramInfo =
};
return {
name: 'AttentionProbs',
shaderCache: {hint: `${components}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: dispatch,
programUniforms
}),
shaderCache: {
hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
inputDependencies
},
getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
getShaderSource,
};
};
const createVxAttentionScoreProgramInfo =
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
pastSequenceLength: number) => {
(context: ComputeContext, probs: TensorView, v: TensorView, pastValue: TensorView|undefined,
params: AttentionParameters, pastSequenceLength: number) => {
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
const presentValueShape =
presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined;
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
const TILE_SIZE = 12;
const dispatch = {
@ -460,23 +509,37 @@ const createVxAttentionScoreProgramInfo =
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: repeatedVHiddenSize}
{type: DataType.uint32, data: repeatedVHiddenSize}, {type: DataType.uint32, data: pastSequenceLength},
{type: DataType.uint32, data: params.kvSequenceLength}
];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
const inputDependencies: ProgramInputTensorInfoDependency[] =
pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
const outputs = [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}];
if (presentValue) {
outputs.push({dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default});
}
const getShaderSource = (shaderHelper: ShaderHelper) => {
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
const vHelper = inputVariable('v', v.dataType, v.dims);
const inputVars = [probsHelper, vHelper];
if (pastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
const output = outputVariable('output', probs.dataType, outputShape);
const outputVars = [output];
if (presentValue) {
outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!));
}
const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}
{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'},
{name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'}
];
return `
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart([
TILE_SIZE, TILE_SIZE, 1
])}
@ -485,16 +548,43 @@ const createVxAttentionScoreProgramInfo =
let n = global_id.x;
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
let offsetB = headIdx * (uniforms.N * uniforms.K) + n;
${(() => {
if (pastValue && presentValue) {
return `
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
`;
} else {
return `
let offsetB = headIdx * uniforms.N * uniforms.K + n;
`;
}
})()}
${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''}
var value = ${probsHelper.type.storage}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (m < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
}
if (n < uniforms.N && w + local_id.y < uniforms.K) {
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N];
}
if (m < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
}
if (n < uniforms.N && w + local_id.y < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastValue && presentValue) {
return `
if (w + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
} else {
tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];
}
`;
} else {
return `
tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];
`;
}
})()}
${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''}
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
@ -515,12 +605,8 @@ const createVxAttentionScoreProgramInfo =
return {
name: 'AttentionScore',
shaderCache: {inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
dispatchGroup: dispatch,
programUniforms
}),
shaderCache: {hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies},
getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}),
getShaderSource,
};
};
@ -529,29 +615,12 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
const outputPresentKey = context.outputCount > 1;
const outputPresentValue = context.outputCount > 2;
const outputCount = context.outputCount;
const pastSequenceLength =
parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
// Concatinate pastKey and K to produce presentKey.
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
const key = parameters.kvNumHeads == null && outputPresentKey ?
context.compute(
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
{inputs: concatKeyInputs, outputs: [1]})[0] :
k;
// Concatinate pastValue and V to produce presentValue.
const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatValueInputs = pastValue ? [pastValue, v] : [v];
const value = parameters.kvNumHeads == null && outputPresentValue ?
context.compute(
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
{inputs: concatValueInputs, outputs: [2]})[0] :
v;
const inputsK = [q, key];
const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k];
if (relativePositionBias) {
inputsK.push(relativePositionBias);
}
@ -559,8 +628,9 @@ export const applyAttention =
// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
{inputs: inputsK, outputs: [-1]})[0];
context, q, k, outputCount > 1 ? pastKey : undefined, relativePositionBias, parameters, attributes,
pastSequenceLength),
{inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0];
// Run Softmax
context.compute(
@ -570,10 +640,12 @@ export const applyAttention =
{inputs: [probs], outputs: []});
// Run AttentionScore
const inputsV = [probs, value];
const inputsV =
(parameters.kvNumHeads === undefined && outputCount > 1 && pastValue) ? [probs, v, pastValue] : [probs, v];
context.compute(
createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
{inputs: inputsV, outputs: [0]});
createVxAttentionScoreProgramInfo(
context, probs, v, outputCount > 1 && pastValue ? pastValue : undefined, parameters, pastSequenceLength),
{inputs: inputsV, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [0, 2] : [0]});
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {

View file

@ -71,7 +71,7 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};
export const createConcatProgramInfo =
const createConcatProgramInfo =
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

View file

@ -322,7 +322,7 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
"name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@ -397,7 +397,7 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@ -474,7 +474,7 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
"name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@ -512,7 +512,8 @@
},
// RelativePositionBias
{
"data": null,
"data": [10, 20],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// PastKey
@ -539,7 +540,7 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
"name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@ -577,7 +578,8 @@
},
// RelativePositionBias
{
"data": null,
"data": [100, 200],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// PastKey
@ -841,7 +843,7 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue",
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@ -976,7 +978,7 @@
],
"outputs": [
{
"data": [3.0006706714630127],
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
@ -1052,7 +1054,7 @@
],
"outputs": [
{
"data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234],
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},