mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[JS/WebGPU] Avoid producing presentKey/presentValue outputs if pastKey/pastValue … (#21782)
Avoid producing presentKey/presentValue outputs if pastKey/pastValue don't exists. ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
a22cc078b4
commit
1fb2e71ddc
3 changed files with 118 additions and 36 deletions
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import { DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
|
||||
|
||||
import {
|
||||
|
|
@ -257,7 +258,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
|
|||
};
|
||||
};
|
||||
|
||||
const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: TensorView, n: number, d: number) => {
|
||||
const createInPlaceSoftmaxProgramInfo = (input: TensorView, n: number, d: number) => {
|
||||
const components = getMaxComponents(d);
|
||||
let WG = 64;
|
||||
const dComp = d / components;
|
||||
|
|
@ -358,7 +359,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
|
|||
};
|
||||
|
||||
const createAttentionProbsProgramInfo = (
|
||||
context: ComputeContext,
|
||||
outputCount: number,
|
||||
q: TensorView,
|
||||
key: TensorView,
|
||||
pastKey: TensorView | undefined,
|
||||
|
|
@ -369,7 +370,7 @@ const createAttentionProbsProgramInfo = (
|
|||
) => {
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
|
||||
const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1;
|
||||
const presentKey = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey;
|
||||
const presentKeyShape = presentKey
|
||||
? [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]
|
||||
: undefined;
|
||||
|
|
@ -394,9 +395,10 @@ const createAttentionProbsProgramInfo = (
|
|||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: parameters.kvSequenceLength },
|
||||
];
|
||||
|
||||
// Feed pastKey to the shader-code only if it is non-zero and presentKey is being produced
|
||||
const feedPastKey = presentKey && pastKey && ShapeUtil.size(pastKey.dims) > 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
if (pastKey) {
|
||||
if (feedPastKey) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
if (attentionBias) {
|
||||
|
|
@ -410,7 +412,7 @@ const createAttentionProbsProgramInfo = (
|
|||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const inputVars = [qInput, kInput];
|
||||
if (pastKey) {
|
||||
if (feedPastKey) {
|
||||
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
|
||||
inputVars.push(pastKeyInput);
|
||||
}
|
||||
|
|
@ -446,7 +448,7 @@ const createAttentionProbsProgramInfo = (
|
|||
let n = workgroup_id.x * TILE_SIZE;
|
||||
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
|
||||
${(() => {
|
||||
if (pastKey && presentKey) {
|
||||
if (feedPastKey && presentKey) {
|
||||
return `
|
||||
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;
|
||||
let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`;
|
||||
|
|
@ -464,7 +466,7 @@ const createAttentionProbsProgramInfo = (
|
|||
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
|
||||
var idx = TILE_SIZE * local_id.y + local_id.x;
|
||||
${(() => {
|
||||
if (pastKey && presentKey) {
|
||||
if (feedPastKey && presentKey) {
|
||||
return `
|
||||
if (n + local_id.y < uniforms.past_sequence_length) {
|
||||
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
|
||||
|
|
@ -513,7 +515,7 @@ const createAttentionProbsProgramInfo = (
|
|||
return {
|
||||
name: 'AttentionProbs',
|
||||
shaderCache: {
|
||||
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
|
||||
hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${outputCount}`,
|
||||
inputDependencies,
|
||||
},
|
||||
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
|
||||
|
|
@ -522,7 +524,7 @@ const createAttentionProbsProgramInfo = (
|
|||
};
|
||||
|
||||
const createVxAttentionScoreProgramInfo = (
|
||||
context: ComputeContext,
|
||||
outputCount: number,
|
||||
probs: TensorView,
|
||||
v: TensorView,
|
||||
pastValue: TensorView | undefined,
|
||||
|
|
@ -532,7 +534,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const nReps = params.nReps ? params.nReps : 1;
|
||||
const repeatedVHiddenSize = params.vHiddenSize * nReps;
|
||||
const presentValue = params.kvNumHeads == null && context.outputCount > 1;
|
||||
const presentValue = params.kvNumHeads == null && outputCount > 1 && pastValue;
|
||||
const presentValueShape = presentValue
|
||||
? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize]
|
||||
: undefined;
|
||||
|
|
@ -553,7 +555,12 @@ const createVxAttentionScoreProgramInfo = (
|
|||
{ type: DataType.uint32, data: pastSequenceLength },
|
||||
{ type: DataType.uint32, data: params.kvSequenceLength },
|
||||
];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = pastValue ? ['type', 'type', 'type'] : ['type', 'type'];
|
||||
// Feed pastValue to the shader-code only if it is non-empty and presentValue is being produced
|
||||
const feedPastValue = presentValue && pastValue && ShapeUtil.size(pastValue.dims) > 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
if (feedPastValue) {
|
||||
inputDependencies.push('type');
|
||||
}
|
||||
const outputs = [{ dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default }];
|
||||
if (presentValue) {
|
||||
outputs.push({ dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default });
|
||||
|
|
@ -562,7 +569,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
|
||||
const vHelper = inputVariable('v', v.dataType, v.dims);
|
||||
const inputVars = [probsHelper, vHelper];
|
||||
if (pastValue) {
|
||||
if (feedPastValue) {
|
||||
inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims));
|
||||
}
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
|
|
@ -591,7 +598,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
|
||||
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
|
||||
${(() => {
|
||||
if (pastValue && presentValue) {
|
||||
if (feedPastValue && presentValue) {
|
||||
return `
|
||||
let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;
|
||||
let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;
|
||||
|
|
@ -611,7 +618,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
if (n < uniforms.N && w + local_id.y < uniforms.K) {
|
||||
var idx = TILE_SIZE * local_id.y + local_id.x;
|
||||
${(() => {
|
||||
if (pastValue && presentValue) {
|
||||
if (feedPastValue && presentValue) {
|
||||
return `
|
||||
if (w + local_id.y < uniforms.past_sequence_length) {
|
||||
tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];
|
||||
|
|
@ -647,7 +654,7 @@ const createVxAttentionScoreProgramInfo = (
|
|||
|
||||
return {
|
||||
name: 'AttentionScore',
|
||||
shaderCache: { hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies },
|
||||
shaderCache: { hint: `${pastValue !== undefined};${outputCount}`, inputDependencies },
|
||||
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
@ -662,15 +669,21 @@ export const applyAttention = (
|
|||
_past: TensorView | undefined,
|
||||
pastKey: TensorView | undefined,
|
||||
pastValue: TensorView | undefined,
|
||||
attentionBias: TensorView | undefined,
|
||||
attentionBiasInput: TensorView | undefined,
|
||||
parameters: AttentionParameters,
|
||||
attributes: AttentionAttrs,
|
||||
) => {
|
||||
const outputCount = context.outputCount;
|
||||
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
|
||||
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
|
||||
const pastSequenceLength = parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0;
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
const attentionBias =
|
||||
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
|
||||
|
||||
const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k];
|
||||
const inputsK = [q, k];
|
||||
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
|
||||
inputsK.push(pastKey);
|
||||
}
|
||||
if (attentionBias) {
|
||||
inputsK.push(attentionBias);
|
||||
}
|
||||
|
|
@ -678,10 +691,10 @@ export const applyAttention = (
|
|||
// Run AttentionProbs
|
||||
const probs = context.compute(
|
||||
createAttentionProbsProgramInfo(
|
||||
context,
|
||||
outputCount,
|
||||
q,
|
||||
k,
|
||||
outputCount > 1 ? pastKey : undefined,
|
||||
pastKey,
|
||||
attentionBias,
|
||||
parameters,
|
||||
attributes,
|
||||
|
|
@ -693,7 +706,6 @@ export const applyAttention = (
|
|||
// Run Softmax
|
||||
context.compute(
|
||||
createInPlaceSoftmaxProgramInfo(
|
||||
context,
|
||||
probs,
|
||||
parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
|
||||
totalSequenceLength,
|
||||
|
|
@ -702,19 +714,14 @@ export const applyAttention = (
|
|||
);
|
||||
|
||||
// Run AttrionScore
|
||||
const inputsV =
|
||||
parameters.kvNumHeads === undefined && outputCount > 1 && pastValue ? [probs, v, pastValue] : [probs, v];
|
||||
context.compute(
|
||||
createVxAttentionScoreProgramInfo(
|
||||
context,
|
||||
probs,
|
||||
v,
|
||||
outputCount > 1 && pastValue ? pastValue : undefined,
|
||||
parameters,
|
||||
pastSequenceLength,
|
||||
),
|
||||
{ inputs: inputsV, outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0] },
|
||||
);
|
||||
const inputsV = [probs, v];
|
||||
if (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
|
||||
inputsV.push(pastValue);
|
||||
}
|
||||
context.compute(createVxAttentionScoreProgramInfo(outputCount, probs, v, pastValue, parameters, pastSequenceLength), {
|
||||
inputs: inputsV,
|
||||
outputs: parameters.kvNumHeads === undefined && outputCount > 1 ? [0, 2] : [0],
|
||||
});
|
||||
};
|
||||
|
||||
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import { inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from '
|
|||
import { createTransposeProgramInfo, TransposeAttributes } from './transpose';
|
||||
|
||||
const getInput = (inputs: readonly TensorView[], i: number) =>
|
||||
inputs.length > i && inputs[i].dims.length > 0 && ShapeUtil.size(inputs[i].dims) > 0 ? inputs[i] : undefined;
|
||||
inputs.length > i && inputs[i].dims.length > 0 ? inputs[i] : undefined;
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
const query = inputs[0];
|
||||
|
|
|
|||
|
|
@ -1073,5 +1073,80 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head and head-size=1 with empty pastKey, pastValue inputs and optional presentKey, presentValue outputs",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
// Q
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// K
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// V
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// Bias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// Mask
|
||||
{
|
||||
"data": null,
|
||||
"type": "int32"
|
||||
},
|
||||
// AttentionBias
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// PastKey
|
||||
{
|
||||
"data": [],
|
||||
"dims": [1, 1, 0, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
// PastValue
|
||||
{
|
||||
"data": [],
|
||||
"dims": [1, 1, 0, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [3],
|
||||
"dims": [1, 1, 1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue