[JS/WebGPU] Avoid producing presentKey/presentValue outputs if pastKey/pastValue … (#21782)

Avoid producing presentKey/presentValue outputs if pastKey/pastValue
don't exists.

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Satya Kumar Jandhyala 2024-08-19 18:02:19 -07:00 committed by GitHub
parent a22cc078b4
commit 1fb2e71ddc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 118 additions and 36 deletions

View file

@ -3,6 +3,7 @@
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
import {
@ -257,7 +258,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
};
};
const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: TensorView, n: number, d: number) => {
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
const components = getMaxComponents(d);
let WG = 64;
const dComp = d / components;
@ -358,7 +359,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
};
const createAttentionProbsProgramInfo = (
context: ComputeContext,
outputCount: number,
q: TensorView,
key: TensorView,
pastKey: TensorView | undefined,
@ -369,7 +370,7 @@ const createAttentionProbsProgramInfo = (
) => {
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
const presentKeyShape = presentKey
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
: undefined;
@ -394,9 +395,10 @@ const createAttentionProbsProgramInfo = (
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: parameters.kvSequenceLength },
];
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (pastKey) {
if (feedPastKey) {
inputDependencies.push('type');
}
if (attentionBias) {
@ -410,7 +412,7 @@ const createAttentionProbsProgramInfo = (
const qInput = inputVariable('q', q.dataType, q.dims, components);
const kInput = inputVariable('key', key.dataType, key.dims, components);
const inputVars = [qInput, kInput];
if (pastKey) {
if (feedPastKey) {
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
}
@ -446,7 +448,7 @@ const createAttentionProbsProgramInfo = (
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
${(() => {
if (pastKey && presentKey) {
if (feedPastKey && presentKey) {
return `
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
@ -464,7 +466,7 @@ const createAttentionProbsProgramInfo = (
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastKey && presentKey) {
if (feedPastKey && presentKey) {
return `
if (n + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
@ -513,7 +515,7 @@ const createAttentionProbsProgramInfo = (
return {
name: 'AttentionProbs',
shaderCache: {
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
inputDependencies,
},
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
@ -522,7 +524,7 @@ const createAttentionProbsProgramInfo = (
};
const createVxAttentionScoreProgramInfo = (
context: ComputeContext,
outputCount: number,
probs: TensorView,
v: TensorView,
pastValue: TensorView | undefined,
@ -532,7 +534,7 @@ const createVxAttentionScoreProgramInfo = (
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const nReps = params.nReps ? params.nReps : 1;
const repeatedVHiddenSize = params.vHiddenSize * nReps;
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
const presentValueShape = presentValue
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
: undefined;
@ -553,7 +555,12 @@ const createVxAttentionScoreProgramInfo = (
{ type: DataType.uint32, data: pastSequenceLength },
{ type: DataType.uint32, data: params.kvSequenceLength },
];
const inputDependencies: ProgramInputTensorInfoDependency[] = pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
if (feedPastValue) {
inputDependencies.push('type');
}
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
if (presentValue) {
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
@ -562,7 +569,7 @@ const createVxAttentionScoreProgramInfo = (
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
const vHelper = inputVariable('v', v.dataType, v.dims);
const inputVars = [probsHelper, vHelper];
if (pastValue) {
if (feedPastValue) {
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
}
const output = outputVariable('output', probs.dataType, outputShape);
@ -591,7 +598,7 @@ const createVxAttentionScoreProgramInfo = (
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
${(() => {
if (pastValue && presentValue) {
if (feedPastValue && presentValue) {
return `
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
@ -611,7 +618,7 @@ const createVxAttentionScoreProgramInfo = (
if (n < uniforms.N && w + local_id.y < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
${(() => {
if (pastValue && presentValue) {
if (feedPastValue && presentValue) {
return `
if (w + local_id.y < uniforms.past_sequence_length) {
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
@ -647,7 +654,7 @@ const createVxAttentionScoreProgramInfo = (
return {
name: 'AttentionScore',
shaderCache: { hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies },
shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
getShaderSource,
};
@ -662,15 +669,21 @@ export const applyAttention = (
_past: TensorView | undefined,
pastKey: TensorView | undefined,
pastValue: TensorView | undefined,
attentionBias: TensorView | undefined,
attentionBiasInput: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
const outputCount = context.outputCount;
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const attentionBias =
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k];
const inputsK = [q, k];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
inputsK.push(pastKey);
}
if (attentionBias) {
inputsK.push(attentionBias);
}
@ -678,10 +691,10 @@ export const applyAttention = (
// Run AttentionProbs
const probs = context.compute(
createAttentionProbsProgramInfo(
context,
outputCount,
q,
k,
outputCount > 1 ? pastKey : undefined,
pastKey,
attentionBias,
parameters,
attributes,
@ -693,7 +706,6 @@ export const applyAttention = (
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
context,
probs,
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
totalSequenceLength,
@ -702,19 +714,14 @@ export const applyAttention = (
);
// Run AttrionScore
const inputsV =
parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v];
context.compute(
createVxAttentionScoreProgramInfo(
context,
probs,
v,
outputCount > 1 && pastValue ? pastValue : undefined,
parameters,
pastSequenceLength,
),
{ inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] },
);
const inputsV = [probs, v];
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
inputsV.push(pastValue);
}
context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
inputs: inputsV,
outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
});
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {

View file

@ -18,7 +18,7 @@ import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from '
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
const getInput = (inputs: readonly TensorView[], i: number) =>
inputs.length > i && inputs[i].dims.length > 0 && ShapeUtil.size(inputs[i].dims) > 0 ? inputs[i] : undefined;
inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined;
const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
const query = inputs[0];

View file

@ -1073,5 +1073,80 @@
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=1 with empty pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// AttentionBias
{
"data": null,
"type": "float32"
},
// PastKey
{
"data": [],
"dims": [1, 1, 0, 1],
"type": "float32"
},
// PastValue
{
"data": [],
"dims": [1, 1, 0, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3],
"dims": [1, 1, 1],
"type": "float32"
},
{
"data": [2],
"dims": [1, 1, 1, 1],
"type": "float32"
},
{
"data": [3],
"dims": [1, 1, 1, 1],
"type": "float32"
}
]
}
]
}
]