From dee6a5b3715c5bdf7a6d29c2b9516902ebd0e0b1 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 9 Jan 2024 23:46:30 +0800 Subject: [PATCH] [js/webgpu] Support uniforms for attention and multihead attention (#18903) --- .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 328 +++++++++--------- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 4 +- .../jsep/webgpu/ops/multi-head-attentiion.ts | 34 +- 4 files changed, 187 insertions(+), 183 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 06c3c6c196..c182d3c4ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; -import {attention, parseAttentionAttributes} from './ops/attention'; +import {attention} from './ops/attention'; import {batchNorm} from './ops/batch-norm'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; @@ -50,7 +50,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], - ['Attention', [attention, parseAttentionAttributes]], + ['Attention', [attention]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BatchNormalization', [batchNorm]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index e1f2a47301..ef8038dff4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {tensorDataTypeEnumToString} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; -import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common'; export const enum AttentionQkvFormat { unknown, // enum value not set, or depends on qkv projection implementation details @@ -231,20 +231,8 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte }; }; -export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => - createAttributeWithCacheKey({...attributes}); - export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { const components = getMaxComponents(d); - const inputHelper = outputVariable('x', input.dataType, input.dims, components); - - let threadMaxValue = 'threadMaxVector'; - if (components === 2) { - threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; - } else if (components === 4) { - threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; - } - const dataType = tensorTypeToWsglStorageType(input.dataType); let WG = 64; const dComp = d / components; if (dComp < WG) { @@ -253,25 +241,41 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); + const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; + const programUniforms: ProgramUniform[] = + [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const dataType = tensorTypeToWsglStorageType(input.dataType, components); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const dInv: ${dataType} = 1 / ${d}; - const dComp = ${d / components}; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + let threadMaxValue = 'thread_max_vector'; + if (components === 2) { + threadMaxValue = 'max(thread_max_vector.x, thread_max_vector.y)'; + } else if (components === 4) { + threadMaxValue = + 'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))'; + } + const elemValueType = tensorTypeToWsglValueType(input.dataType); + const uniforms: UniformsArrayType = [ + {name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'}, + {name: 'elements_per_wg', type: 'u32'} + ]; + + return ` var wgMax: array; var wgSum: array; + ${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)} + ${shaderHelper.mainStart([ + WG, 1, 1 + ])} + let localOffset = local_idx * uniforms.elements_per_wg; + let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset; - ${shaderHelper.declareVariables(inputHelper)} - @compute @workgroup_size(${WG}, 1, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_index) local_index : u32) { - let localOffset = local_index * ${elementsPerWG}; - let offset: u32 = workgroup_id.x * dComp + localOffset; - - var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector); } - wgMax[local_index] = ${threadMaxValue}; + wgMax[local_idx] = ${threadMaxValue}; workgroupBarrier(); var maxValue = -3.402823e+38f; @@ -280,10 +284,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } var sumVector = ${fillVector('f32', components, '0')}; - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue); } - wgSum[local_index] = ${sumVector('sumVector', components)}; + wgSum[local_idx] = ${sumVector('sumVector', components)}; workgroupBarrier(); var sum: f32 = 0; @@ -292,26 +296,24 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView } if (sum == 0) { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; } } else { - for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { - let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { + let f32input = ${castToF32(elemValueType, components, 'x[offset + i]')}; x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); } } }`; + }; context.compute( { name: 'AttentionProbsSoftmax', - shaderCache: {hint: `${d}`}, + shaderCache: {hint: `${WG};${dataType};${components}`}, getShaderSource, - getRunData: () => ({ - outputs: [], - dispatchGroup: {x: n}, - }), + getRunData: () => ({outputs: [], dispatchGroup: {x: n}, programUniforms}), }, {inputs: [input], outputs: []}); }; @@ -326,47 +328,43 @@ const computeAttentionProbs = // TODO: handle mask const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - - const dataType = tensorTypeToWsglStorageType(q.dataType); - const components = getMaxComponents(parameters.headSize); - const qInput = inputVariable('q', q.dataType, q.dims, components); - const kInput = inputVariable('key', key.dataType, key.dims, components); - const output = outputVariable('output', q.dataType, probsShape); - const vectorizedHeadSize = parameters.headSize / components; - const M = parameters.sequenceLength; - const N = parameters.totalSequenceLength; - const K = vectorizedHeadSize; - const TILE_SIZE = 12; - const dispatch = { x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, + {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, + {type: tensorDataType, data: alpha} + ]; const inputs = [q, key]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${M}u; - const N: u32 = ${N}u; - const K: u32 = ${K}u; - const alpha: ${dataType} = ${alpha}; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} + ]; + return ` const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(qInput, kInput, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + ${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} // x holds the N and y holds the M let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE; @@ -374,40 +372,42 @@ const computeAttentionProbs = let lm = m + local_id.y; let ln = n + local_id.x; - let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; - let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K; + let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K; var value = ${fillVector(dataType, components)}; - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m + local_id.y < M && w + local_id.x < K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x]; } - if (n + local_id.y < N && w + local_id.x < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -423,78 +423,76 @@ const computeAttentionProbs = const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; - - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); - const output = outputVariable('output', probs.dataType, outputShape); - - const dataType = tensorTypeToWsglStorageType(probs.dataType); - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), y: Math.ceil(params.sequenceLength / TILE_SIZE), z: params.batchSize * params.numHeads }; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, + {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, + {type: 'uint32', data: params.vHiddenSize} + ]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const M: u32 = ${params.sequenceLength}u; - const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; - const numHeads: u32 = ${params.numHeads}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, + {name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - - var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; - - ${shaderHelper.declareVariables(probsHelper, vHelper, output)} - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - + var tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>; + ${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} let headIdx = workgroup_id.z; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; - let offsetA = headIdx * (M * K) + m * K; - let offsetB = headIdx * (N * K) + n; + let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K; + let offsetB = headIdx * (uniforms.N * uniforms.K) + n; - var value = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + var value = ${probsHelper.type.storage}(0); + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N]; } workgroupBarrier(); - for (var k: u32 = 0u; k ({ outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, @@ -517,71 +515,71 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { parameters.sequenceLength, parameters.headSize, ]; - - const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); - const M = parameters.sequenceLength; const K = parameters.inputHiddenSize; const N = parameters.headSize; - const TILE_SIZE = 12; const dispatch = { x: Math.ceil(parameters.headSize / TILE_SIZE), y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; + const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, + {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, + {type: 'uint32', data: parameters.hiddenSize}, + {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + ]; - const getShaderSource = () => ` - const M: u32 = ${M}u; - const K: u32 = ${K}u; - const N: u32 = ${N}u; - const numHeads: u32 = ${parameters.numHeads}; - const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape); + const outputK = outputVariable('output_k', inputs[0].dataType, outputShape); + const outputV = outputVariable('output_v', inputs[0].dataType, outputShape); + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims); + const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const dataType = input.type.storage; + + const uniforms: UniformsArrayType = [ + {name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'num_heads', type: 'u32'}, + {name: 'head_size', type: 'u32'}, {name: 'hidden_size', type: 'u32'}, {name: 'ldb', type: 'u32'} + ]; + return ` const TILE_SIZE = ${TILE_SIZE}u; - var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; - - @group(0) @binding(0) var input: array<${dataType}>; - @group(0) @binding(1) var weight: array<${dataType}>; - @group(0) @binding(2) var bias: array<${dataType}>; - @group(0) @binding(3) var outputQ: array<${dataType}>; - @group(0) @binding(4) var outputK: array<${dataType}>; - @group(0) @binding(5) var outputV: array<${dataType}>; - - @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) - fn main(@builtin(workgroup_id) workgroup_id : vec3, - @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { - let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + - workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - - let batchIndex = workgroup_id.z / ${parameters.numHeads}; - let headNumber = workgroup_id.z % ${parameters.numHeads}; + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)} + ${shaderHelper.mainStart([ + TILE_SIZE, TILE_SIZE, 1 + ])} + let batchIndex = workgroup_id.z / uniforms.num_heads; + let headNumber = workgroup_id.z % uniforms.num_heads; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; - let inputOffset = batchIndex * (M * K) + m * K; - let biasOffsetQ = headNumber * ${parameters.headSize}; - let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; - let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K; + let biasOffsetQ = headNumber * uniforms.head_size; + let biasOffsetK = uniforms.hidden_size + biasOffsetQ; + let biasOffsetV = uniforms.hidden_size + biasOffsetK; var valueQ = ${dataType}(0); var valueK = ${dataType}(0); var valueV = ${dataType}(0); - for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { + for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) { + if (m < uniforms.M && w + local_id.x < uniforms.K) { tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - let offset = n + (w + local_id.y) * ldb; + if (n < uniforms.N && w + local_id.y < uniforms.K) { + let offset = n + (w + local_id.y) * uniforms.ldb; tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; } workgroupBarrier(); - for (var k: u32 = 0u; k { workgroupBarrier(); } - let headOffset = (m * N + n) % ${parameters.headSize}; + let headOffset = (m * uniforms.N + n) % uniforms.head_size; valueQ += bias[headOffset + biasOffsetQ]; valueK += bias[headOffset + biasOffsetK]; valueV += bias[headOffset + biasOffsetV]; - let offset = workgroup_id.z * M * N; - if (m < M && n < N) { - let outputIdx = offset + m * N + n; - outputQ[outputIdx] = valueQ; - outputK[outputIdx] = valueK; - outputV[outputIdx] = valueV; + let offset = workgroup_id.z * uniforms.M * uniforms.N; + if (m < uniforms.M && n < uniforms.N) { + let outputIdx = offset + m * uniforms.N + n; + output_q[outputIdx] = valueQ; + output_k[outputIdx] = valueK; + output_v[outputIdx] = valueV; } }`; - - const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + }; return context.compute( { name: 'AttentionPrepare', - shaderCache: {hint: JSON.stringify(parameters)}, + shaderCache: {inputDependencies: ['type', 'type', 'type']}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, @@ -619,6 +616,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, ], dispatchGroup: dispatch, + programUniforms }), getShaderSource, }, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 3ce114c5d3..bc3265be95 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -780,8 +780,10 @@ class ShaderHelperImpl implements ShaderHelper { const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1; const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, + @builtin(workgroup_id) workgroup_id : vec3, @builtin(local_invocation_id) local_id : vec3` : - `@builtin(local_invocation_index) local_idx : u32, + `@builtin(local_invocation_id) local_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(workgroup_id) workgroup_id : vec3, @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index b7726a36bc..6d22e3780e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -4,10 +4,10 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, GpuDataType} from '../types'; +import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; -import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { @@ -228,7 +228,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr }; }; - export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => createAttributeWithCacheKey({...attributes}); @@ -239,30 +238,35 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; - const dataType = tensorTypeToWsglStorageType(qkv.dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const biasOffset = ${biasOffset}u; - const hiddenSize = ${hiddenSize}u; - - @group(0) @binding(0) var qkv: array<${dataType}>; - @group(0) @binding(1) var bias: array<${dataType}>; - @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); + const qkvInput = inputVariable('qkv', qkv.dataType, outputShape); + const biasInput = inputVariable('bias', bias.dataType, outputShape); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'bias_offset', type: 'u32'}, {name: 'hidden_size', type: 'u32'} + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, biasInput, output)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} - let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset; - qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx]; }`; + }; return context.compute( { name: 'MultiHeadAttentionAddBias', - shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + shaderCache: {inputDependencies: ['type', 'type']}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, },