diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index fda2ff64b0..435267a1b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -6,7 +6,6 @@ import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; import {getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; -import {createConcatProgramInfo} from './concat'; export const enum AttentionQkvFormat { unknown, // enum value not set, or depends on qkv projection implementation details @@ -336,10 +335,15 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor }; const createAttentionProbsProgramInfo = - (_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined, - parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => { + (context: ComputeContext, q: TensorView, key: TensorView, pastKey: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs, + pastSequenceLength: number) => { const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength]; + const presentKey = parameters.kvNumHeads === undefined && context.outputCount > 1; + const presentKeyShape = presentKey ? + [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize] : + undefined; // TODO: handle mask @@ -355,34 +359,51 @@ const createAttentionProbsProgramInfo = const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads}, - {type: DataType.float, data: alpha} + {type: DataType.float, data: alpha}, {type: DataType.uint32, data: pastSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength} ]; - const inputDependencies: ProgramInputTensorInfoDependency[] = - relativePositionBias ? ['type', 'type', 'type'] : ['type', 'type']; - + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + if (pastKey) { + inputDependencies.push('type'); + } + if (relativePositionBias) { + inputDependencies.push('type'); + } + const outputs = [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}]; + if (presentKey) { + outputs.push({dims: presentKeyShape!, dataType: q.dataType, gpuDataType: GpuDataType.default}); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const qInput = inputVariable('q', q.dataType, q.dims, components); const kInput = inputVariable('key', key.dataType, key.dims, components); const inputVars = [qInput, kInput]; + if (pastKey) { + const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components); + inputVars.push(pastKeyInput); + } if (relativePositionBias) { inputVars.push( inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims)); } const output = outputVariable('output', q.dataType, probsShape); - // const dataType = tensorTypeToWsglStorageType(q.dataType); + const outputVars = [output]; + if (presentKey) { + outputVars.push(outputVariable('present_key', q.dataType, presentKeyShape!, components)); + } const f32Type = tensorTypeToWsglValueType(DataType.float, components); const uniforms: UniformsArrayType = [ {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType} + {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: 'f32' as UniformDataElementType}, + {name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'} ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([ TILE_SIZE, TILE_SIZE, 1 ])} @@ -391,15 +412,41 @@ const createAttentionProbsProgramInfo = let m = workgroup_id.y * TILE_SIZE; let n = workgroup_id.x * TILE_SIZE; let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; - let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K; - + ${(() => { + if (pastKey && presentKey) { + return ` + let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx; + let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;`; + } else { + return ` + let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;`; + } + })()} + ${presentKey ? 'let presentKeyOffset = headIdx * uniforms.N * uniforms.K;' : ''} var value = ${f32Type}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; } if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x]; + var idx = TILE_SIZE * local_id.y + local_id.x; + ${(() => { + if (pastKey && presentKey) { + return ` + if (n + local_id.y < uniforms.past_sequence_length) { + tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x]; + } else { + tileK[idx] = + key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x]; + }`; + } else { + return 'tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];'; + } + })()} + ${ + presentKey ? + 'present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];' : + ''} } workgroupBarrier(); @@ -432,23 +479,25 @@ const createAttentionProbsProgramInfo = }; return { name: 'AttentionProbs', - shaderCache: {hint: `${components}`, inputDependencies}, - getRunData: () => ({ - outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: dispatch, - programUniforms - }), + shaderCache: { + hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`, + inputDependencies + }, + getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}), getShaderSource, }; }; const createVxAttentionScoreProgramInfo = - (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters, - pastSequenceLength: number) => { + (context: ComputeContext, probs: TensorView, v: TensorView, pastValue: TensorView|undefined, + params: AttentionParameters, pastSequenceLength: number) => { const totalSequenceLength = pastSequenceLength + params.kvSequenceLength; const nReps = params.nReps ? params.nReps : 1; const repeatedVHiddenSize = params.vHiddenSize * nReps; + const presentValue = params.kvNumHeads == null && context.outputCount > 1; + const presentValueShape = + presentValue ? [params.batchSize, params.numHeads, totalSequenceLength, params.headSize] : undefined; const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize]; const TILE_SIZE = 12; const dispatch = { @@ -460,23 +509,37 @@ const createVxAttentionScoreProgramInfo = const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, - {type: DataType.uint32, data: repeatedVHiddenSize} + {type: DataType.uint32, data: repeatedVHiddenSize}, {type: DataType.uint32, data: pastSequenceLength}, + {type: DataType.uint32, data: params.kvSequenceLength} ]; - - const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const inputDependencies: ProgramInputTensorInfoDependency[] = + pastValue ? ['type', 'type', 'type'] : ['type', 'type']; + const outputs = [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}]; + if (presentValue) { + outputs.push({dims: presentValueShape!, dataType: probs.dataType, gpuDataType: GpuDataType.default}); + } const getShaderSource = (shaderHelper: ShaderHelper) => { const probsHelper = inputVariable('probs', probs.dataType, probs.dims); const vHelper = inputVariable('v', v.dataType, v.dims); + const inputVars = [probsHelper, vHelper]; + if (pastValue) { + inputVars.push(inputVariable('past_value', pastValue.dataType, pastValue.dims)); + } const output = outputVariable('output', probs.dataType, outputShape); + const outputVars = [output]; + if (presentValue) { + outputVars.push(outputVariable('present_value', probs.dataType, presentValueShape!)); + } const uniforms: UniformsArrayType = [ {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, - {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'} + {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}, + {name: 'past_sequence_length', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'} ]; return ` const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; - ${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)} ${shaderHelper.mainStart([ TILE_SIZE, TILE_SIZE, 1 ])} @@ -485,16 +548,43 @@ const createVxAttentionScoreProgramInfo = let n = global_id.x; let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; - let offsetB = headIdx * (uniforms.N * uniforms.K) + n; - + ${(() => { + if (pastValue && presentValue) { + return ` + let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n; + let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n; + `; + } else { + return ` + let offsetB = headIdx * uniforms.N * uniforms.K + n; + `; + } + })()} + ${presentValue ? 'let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;' : ''} var value = ${probsHelper.type.storage}(0); for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { - if (m < uniforms.M && w + local_id.x < uniforms.K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; - } - if (n < uniforms.N && w + local_id.y < uniforms.K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N]; - } + if (m < uniforms.M && w + local_id.x < uniforms.K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < uniforms.N && w + local_id.y < uniforms.K) { + var idx = TILE_SIZE * local_id.y + local_id.x; + ${(() => { + if (pastValue && presentValue) { + return ` + if (w + local_id.y < uniforms.past_sequence_length) { + tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N]; + } else { + tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N]; + } + `; + } else { + return ` + tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N]; + `; + } + })()} + ${presentValue ? 'present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];' : ''} + } workgroupBarrier(); for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) { value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x]; @@ -515,12 +605,8 @@ const createVxAttentionScoreProgramInfo = return { name: 'AttentionScore', - shaderCache: {inputDependencies}, - getRunData: () => ({ - outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], - dispatchGroup: dispatch, - programUniforms - }), + shaderCache: {hint: `${pastValue !== undefined};${context.outputCount}`, inputDependencies}, + getRunData: () => ({outputs, dispatchGroup: dispatch, programUniforms}), getShaderSource, }; }; @@ -529,29 +615,12 @@ export const applyAttention = (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, _past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined, relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { - const outputPresentKey = context.outputCount > 1; - const outputPresentValue = context.outputCount > 2; + const outputCount = context.outputCount; const pastSequenceLength = - parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0; + parameters.kvNumHeads !== undefined || outputCount > 1 ? parameters.pastSequenceLength : 0; const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength; - // Concatinate pastKey and K to produce presentKey. - const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]; - const concatKeyInputs = pastKey ? [pastKey, k] : [k]; - const key = parameters.kvNumHeads == null && outputPresentKey ? - context.compute( - createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType), - {inputs: concatKeyInputs, outputs: [1]})[0] : - k; - // Concatinate pastValue and V to produce presentValue. - const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize]; - const concatValueInputs = pastValue ? [pastValue, v] : [v]; - const value = parameters.kvNumHeads == null && outputPresentValue ? - context.compute( - createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType), - {inputs: concatValueInputs, outputs: [2]})[0] : - v; - const inputsK = [q, key]; + const inputsK = (parameters.kvNumHeads === undefined && outputCount > 1 && pastKey) ? [q, k, pastKey] : [q, k]; if (relativePositionBias) { inputsK.push(relativePositionBias); } @@ -559,8 +628,9 @@ export const applyAttention = // Run AttentionProbs const probs = context.compute( createAttentionProbsProgramInfo( - context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength), - {inputs: inputsK, outputs: [-1]})[0]; + context, q, k, outputCount > 1 ? pastKey : undefined, relativePositionBias, parameters, attributes, + pastSequenceLength), + {inputs: inputsK, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [-1, 1] : [-1]})[0]; // Run Softmax context.compute( @@ -570,10 +640,12 @@ export const applyAttention = {inputs: [probs], outputs: []}); // Run AttrionScore - const inputsV = [probs, value]; + const inputsV = + (parameters.kvNumHeads === undefined && outputCount > 1 && pastValue) ? [probs, v, pastValue] : [probs, v]; context.compute( - createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength), - {inputs: inputsV, outputs: [0]}); + createVxAttentionScoreProgramInfo( + context, probs, v, outputCount > 1 && pastValue ? pastValue : undefined, parameters, pastSequenceLength), + {inputs: inputsV, outputs: (parameters.kvNumHeads === undefined && outputCount > 1) ? [0, 2] : [0]}); }; const prepare = (context: ComputeContext, parameters: AttentionParameters) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index fedffa27f2..010ee589c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -71,7 +71,7 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe return codeLines.join('\n'); }; -export const createConcatProgramInfo = +const createConcatProgramInfo = (inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => { const outputSize = ShapeUtil.size(outputShape); diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc index 2c5dd30df9..6ce6a5e0a8 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -322,7 +322,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -397,7 +397,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -474,7 +474,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -512,7 +512,8 @@ }, // RelativePositionBias { - "data": null, + "data": [10, 20], + "dims": [1, 1, 1, 2], "type": "float32" }, // PastKey @@ -539,7 +540,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue", + "name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -577,7 +578,8 @@ }, // RelativePositionBias { - "data": null, + "data": [100, 200], + "dims": [1, 1, 1, 2], "type": "float32" }, // PastKey @@ -841,7 +843,7 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue", + "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -976,7 +978,7 @@ ], "outputs": [ { - "data": [3.0006706714630127], + "data": [3], "dims": [1, 1, 1], "type": "float32" }, @@ -1052,7 +1054,7 @@ ], "outputs": [ { - "data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234], + "data": [9, 10, 11, 12], "dims": [1, 1, 4], "type": "float32" },