mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
[js/webgpu] Support uniforms for attention and multihead attention (#18903)
This commit is contained in:
parent
ab897a4a40
commit
dee6a5b371
4 changed files with 187 additions and 183 deletions
|
|
@ -2,7 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
|
||||
import {attention, parseAttentionAttributes} from './ops/attention';
|
||||
import {attention} from './ops/attention';
|
||||
import {batchNorm} from './ops/batch-norm';
|
||||
import {biasAdd} from './ops/bias-add';
|
||||
import {biasSplitGelu} from './ops/bias-split-gelu';
|
||||
|
|
@ -50,7 +50,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Asinh', [unaryOps.asinh]],
|
||||
['Atan', [unaryOps.atan]],
|
||||
['Atanh', [unaryOps.atanh]],
|
||||
['Attention', [attention, parseAttentionAttributes]],
|
||||
['Attention', [attention]],
|
||||
// TODO: support new attributes for AveragePool-10
|
||||
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
|
||||
['BatchNormalization', [batchNorm]],
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {tensorDataTypeEnumToString} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType} from '../types';
|
||||
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';
|
||||
|
||||
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
|
||||
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, tensorTypeToWsglValueType, UniformDataElementType, UniformsArrayType} from './common';
|
||||
|
||||
export const enum AttentionQkvFormat {
|
||||
unknown, // enum value not set, or depends on qkv projection implementation details
|
||||
|
|
@ -231,20 +231,8 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
|
|||
};
|
||||
};
|
||||
|
||||
export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({...attributes});
|
||||
|
||||
export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => {
|
||||
const components = getMaxComponents(d);
|
||||
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
|
||||
|
||||
let threadMaxValue = 'threadMaxVector';
|
||||
if (components === 2) {
|
||||
threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
|
||||
} else if (components === 4) {
|
||||
threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
|
||||
}
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType);
|
||||
let WG = 64;
|
||||
const dComp = d / components;
|
||||
if (dComp < WG) {
|
||||
|
|
@ -253,25 +241,41 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
|||
WG = Math.ceil(dComp / 8);
|
||||
}
|
||||
const elementsPerWG = Math.ceil(d / components / WG);
|
||||
const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}];
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const dInv: ${dataType} = 1 / ${d};
|
||||
const dComp = ${d / components};
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
|
||||
let threadMaxValue = 'thread_max_vector';
|
||||
if (components === 2) {
|
||||
threadMaxValue = 'max(thread_max_vector.x, thread_max_vector.y)';
|
||||
} else if (components === 4) {
|
||||
threadMaxValue =
|
||||
'max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))';
|
||||
}
|
||||
const elemValueType = tensorTypeToWsglValueType(input.dataType);
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'd_inv', type: elemValueType as UniformDataElementType}, {name: 'd_comp', type: 'u32'},
|
||||
{name: 'elements_per_wg', type: 'u32'}
|
||||
];
|
||||
|
||||
return `
|
||||
var<workgroup> wgMax: array<f32, ${WG}>;
|
||||
var<workgroup> wgSum: array<f32, ${WG}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(inputHelper)}
|
||||
${shaderHelper.mainStart([
|
||||
WG, 1, 1
|
||||
])}
|
||||
let localOffset = local_idx * uniforms.elements_per_wg;
|
||||
let offset: u32 = workgroup_id.x * uniforms.d_comp + localOffset;
|
||||
|
||||
${shaderHelper.declareVariables(inputHelper)}
|
||||
@compute @workgroup_size(${WG}, 1, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_index) local_index : u32) {
|
||||
let localOffset = local_index * ${elementsPerWG};
|
||||
let offset: u32 = workgroup_id.x * dComp + localOffset;
|
||||
|
||||
var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')};
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector);
|
||||
var thread_max_vector = ${fillVector('f32', components, '-3.402823e+38f')};
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
thread_max_vector = max(${castToF32(elemValueType, components, 'x[offset + i]')}, thread_max_vector);
|
||||
}
|
||||
wgMax[local_index] = ${threadMaxValue};
|
||||
wgMax[local_idx] = ${threadMaxValue};
|
||||
workgroupBarrier();
|
||||
|
||||
var maxValue = -3.402823e+38f;
|
||||
|
|
@ -280,10 +284,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
|||
}
|
||||
|
||||
var sumVector = ${fillVector('f32', components, '0')};
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue);
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
sumVector += exp(${castToF32(elemValueType, components, 'x[offset + i]')} - maxValue);
|
||||
}
|
||||
wgSum[local_index] = ${sumVector('sumVector', components)};
|
||||
wgSum[local_idx] = ${sumVector('sumVector', components)};
|
||||
workgroupBarrier();
|
||||
|
||||
var sum: f32 = 0;
|
||||
|
|
@ -292,26 +296,24 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
|||
}
|
||||
|
||||
if (sum == 0) {
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
x[offset + i] = ${fillVector(dataType, components, 'dInv')};
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
|
||||
}
|
||||
} else {
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
let f32input = ${castToF32(dataType, components, 'x[offset + i]')};
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
let f32input = ${castToF32(elemValueType, components, 'x[offset + i]')};
|
||||
x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum);
|
||||
}
|
||||
}
|
||||
}`;
|
||||
};
|
||||
|
||||
context.compute(
|
||||
{
|
||||
name: 'AttentionProbsSoftmax',
|
||||
shaderCache: {hint: `${d}`},
|
||||
shaderCache: {hint: `${WG};${dataType};${components}`},
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [],
|
||||
dispatchGroup: {x: n},
|
||||
}),
|
||||
getRunData: () => ({outputs: [], dispatchGroup: {x: n}, programUniforms}),
|
||||
},
|
||||
{inputs: [input], outputs: []});
|
||||
};
|
||||
|
|
@ -326,47 +328,43 @@ const computeAttentionProbs =
|
|||
// TODO: handle mask
|
||||
|
||||
const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(q.dataType);
|
||||
|
||||
const components = getMaxComponents(parameters.headSize);
|
||||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const output = outputVariable('output', q.dataType, probsShape);
|
||||
|
||||
const vectorizedHeadSize = parameters.headSize / components;
|
||||
const M = parameters.sequenceLength;
|
||||
const N = parameters.totalSequenceLength;
|
||||
const K = vectorizedHeadSize;
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
|
||||
const dispatch = {
|
||||
x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
|
||||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type'];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize},
|
||||
{type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength},
|
||||
{type: tensorDataType, data: alpha}
|
||||
];
|
||||
|
||||
const inputs = [q, key];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const M: u32 = ${M}u;
|
||||
const N: u32 = ${N}u;
|
||||
const K: u32 = ${K}u;
|
||||
const alpha: ${dataType} = ${alpha};
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const output = outputVariable('output', q.dataType, probsShape);
|
||||
const dataType = tensorTypeToWsglStorageType(q.dataType);
|
||||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'kv_sequence_length', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
|
||||
];
|
||||
return `
|
||||
const beta: ${dataType} = 1.0;
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
${shaderHelper.declareVariables(qInput, kInput, output)}
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(qInput, kInput, output)}
|
||||
${shaderHelper.mainStart([
|
||||
TILE_SIZE, TILE_SIZE, 1
|
||||
])}
|
||||
// x holds the N and y holds the M
|
||||
let headIdx = workgroup_id.z;
|
||||
let m = workgroup_id.y * TILE_SIZE;
|
||||
|
|
@ -374,40 +372,42 @@ const computeAttentionProbs =
|
|||
let lm = m + local_id.y;
|
||||
let ln = n + local_id.x;
|
||||
|
||||
let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
|
||||
let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
|
||||
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
|
||||
let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
|
||||
|
||||
var value = ${fillVector(dataType, components)};
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m + local_id.y < M && w + local_id.x < K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x];
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
|
||||
}
|
||||
if (n + local_id.y < N && w + local_id.x < K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x];
|
||||
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let headOffset = headIdx * M * N;
|
||||
if (lm < M && ln < N) {
|
||||
let outputIdx = headOffset + lm * N + ln;
|
||||
output[outputIdx] = ${sumVector('value', components)} * alpha;
|
||||
let headOffset = headIdx * uniforms.M * uniforms.N;
|
||||
if (lm < uniforms.M && ln < uniforms.N) {
|
||||
let outputIdx = headOffset + lm * uniforms.N + ln;
|
||||
output[outputIdx] = ${sumVector('value', components)} * uniforms.alpha;
|
||||
}
|
||||
}`;
|
||||
};
|
||||
|
||||
const probs = context.compute(
|
||||
{
|
||||
name: 'AttentionProbs',
|
||||
shaderCache: {hint: JSON.stringify(parameters)},
|
||||
shaderCache: {hint: `${components}`, inputDependencies: ['type', 'type']},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
|
|
@ -423,78 +423,76 @@ const computeAttentionProbs =
|
|||
const computeVxAttentionScore =
|
||||
(context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
|
||||
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
|
||||
|
||||
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
|
||||
const vHelper = inputVariable('v', v.dataType, v.dims);
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(probs.dataType);
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(params.vHeadSize / TILE_SIZE),
|
||||
y: Math.ceil(params.sequenceLength / TILE_SIZE),
|
||||
z: params.batchSize * params.numHeads
|
||||
};
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength},
|
||||
{type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads},
|
||||
{type: 'uint32', data: params.vHiddenSize}
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const M: u32 = ${params.sequenceLength}u;
|
||||
const N: u32 = ${params.vHeadSize}u;
|
||||
const K: u32 = ${params.totalSequenceLength}u;
|
||||
const numHeads: u32 = ${params.numHeads}u;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
|
||||
const vHelper = inputVariable('v', v.dataType, v.dims);
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
|
||||
{name: 'num_heads', type: 'u32'}, {name: 'v_hidden_size', type: 'u32'}
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
${shaderHelper.declareVariables(probsHelper, vHelper, output)}
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
var<workgroup> tileQ: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${probsHelper.type.value}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(probsHelper, vHelper, output)}
|
||||
${shaderHelper.mainStart([
|
||||
TILE_SIZE, TILE_SIZE, 1
|
||||
])}
|
||||
let headIdx = workgroup_id.z;
|
||||
let m = workgroup_id.y * TILE_SIZE + local_id.y;
|
||||
let n = workgroup_id.x * TILE_SIZE + local_id.x;
|
||||
|
||||
let offsetA = headIdx * (M * K) + m * K;
|
||||
let offsetB = headIdx * (N * K) + n;
|
||||
let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;
|
||||
let offsetB = headIdx * (uniforms.N * uniforms.K) + n;
|
||||
|
||||
var value = ${dataType}(0);
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m < M && w + local_id.x < K) {
|
||||
var value = ${probsHelper.type.storage}(0);
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (m < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
|
||||
}
|
||||
if (n < N && w + local_id.y < K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
|
||||
if (n < uniforms.N && w + local_id.y < uniforms.K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * uniforms.N];
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// we need to transpose output from BNSH_v to BSND_v
|
||||
let batchIdx = workgroup_id.z / ${params.numHeads};
|
||||
let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
|
||||
let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize};
|
||||
if (m < M && n < N) {
|
||||
let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
|
||||
+ currentBatchHeadNumber * ${params.vHeadSize} + n;
|
||||
let batchIdx = workgroup_id.z / uniforms.num_heads;
|
||||
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
|
||||
let headOffset = (batchIdx * uniforms.M * uniforms.num_heads + currentBatchHeadNumber) * uniforms.N;
|
||||
if (m < uniforms.M && n < uniforms.N) {
|
||||
let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size
|
||||
+ currentBatchHeadNumber * uniforms.N + n;
|
||||
output[outputIdx] = value;
|
||||
}
|
||||
}`;
|
||||
};
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'AttentionScore',
|
||||
shaderCache: {hint: JSON.stringify(params)},
|
||||
shaderCache: {inputDependencies: ['type', 'type']},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
|
|
@ -517,71 +515,71 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
|||
parameters.sequenceLength,
|
||||
parameters.headSize,
|
||||
];
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
|
||||
const M = parameters.sequenceLength;
|
||||
const K = parameters.inputHiddenSize;
|
||||
const N = parameters.headSize;
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(parameters.headSize / TILE_SIZE),
|
||||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N},
|
||||
{type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize},
|
||||
{type: 'uint32', data: parameters.hiddenSize},
|
||||
{type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
|
||||
];
|
||||
|
||||
const getShaderSource = () => `
|
||||
const M: u32 = ${M}u;
|
||||
const K: u32 = ${K}u;
|
||||
const N: u32 = ${N}u;
|
||||
const numHeads: u32 = ${parameters.numHeads};
|
||||
const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const outputQ = outputVariable('output_q', inputs[0].dataType, outputShape);
|
||||
const outputK = outputVariable('output_k', inputs[0].dataType, outputShape);
|
||||
const outputV = outputVariable('output_v', inputs[0].dataType, outputShape);
|
||||
const input = inputVariable('input', inputs[0].dataType, inputs[0].dims);
|
||||
const weight = inputVariable('weight', inputs[1].dataType, inputs[1].dims);
|
||||
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
|
||||
const dataType = input.type.storage;
|
||||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'num_heads', type: 'u32'},
|
||||
{name: 'head_size', type: 'u32'}, {name: 'hidden_size', type: 'u32'}, {name: 'ldb', type: 'u32'}
|
||||
];
|
||||
return `
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> input: array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read> weight: array<${dataType}>;
|
||||
@group(0) @binding(2) var<storage, read> bias: array<${dataType}>;
|
||||
@group(0) @binding(3) var<storage, read_write> outputQ: array<${dataType}>;
|
||||
@group(0) @binding(4) var<storage, read_write> outputK: array<${dataType}>;
|
||||
@group(0) @binding(5) var<storage, read_write> outputV: array<${dataType}>;
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
let batchIndex = workgroup_id.z / ${parameters.numHeads};
|
||||
let headNumber = workgroup_id.z % ${parameters.numHeads};
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(input, weight, bias, outputQ, outputK, outputV)}
|
||||
${shaderHelper.mainStart([
|
||||
TILE_SIZE, TILE_SIZE, 1
|
||||
])}
|
||||
let batchIndex = workgroup_id.z / uniforms.num_heads;
|
||||
let headNumber = workgroup_id.z % uniforms.num_heads;
|
||||
let m = workgroup_id.y * TILE_SIZE + local_id.y;
|
||||
let n = workgroup_id.x * TILE_SIZE + local_id.x;
|
||||
|
||||
let inputOffset = batchIndex * (M * K) + m * K;
|
||||
let biasOffsetQ = headNumber * ${parameters.headSize};
|
||||
let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ;
|
||||
let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK;
|
||||
let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;
|
||||
let biasOffsetQ = headNumber * uniforms.head_size;
|
||||
let biasOffsetK = uniforms.hidden_size + biasOffsetQ;
|
||||
let biasOffsetV = uniforms.hidden_size + biasOffsetK;
|
||||
|
||||
var valueQ = ${dataType}(0);
|
||||
var valueK = ${dataType}(0);
|
||||
var valueV = ${dataType}(0);
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m < M && w + local_id.x < K) {
|
||||
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
|
||||
if (m < uniforms.M && w + local_id.x < uniforms.K) {
|
||||
tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
|
||||
}
|
||||
if (n < N && w + local_id.y < K) {
|
||||
let offset = n + (w + local_id.y) * ldb;
|
||||
if (n < uniforms.N && w + local_id.y < uniforms.K) {
|
||||
let offset = n + (w + local_id.y) * uniforms.ldb;
|
||||
tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
|
||||
tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
|
||||
tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < uniforms.K; k++) {
|
||||
let inputTileOffset = TILE_SIZE * local_id.y + k;
|
||||
let weightTileOffset = TILE_SIZE * k + local_id.x;
|
||||
valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
|
||||
|
|
@ -592,26 +590,25 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
|||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let headOffset = (m * N + n) % ${parameters.headSize};
|
||||
let headOffset = (m * uniforms.N + n) % uniforms.head_size;
|
||||
valueQ += bias[headOffset + biasOffsetQ];
|
||||
valueK += bias[headOffset + biasOffsetK];
|
||||
valueV += bias[headOffset + biasOffsetV];
|
||||
|
||||
let offset = workgroup_id.z * M * N;
|
||||
if (m < M && n < N) {
|
||||
let outputIdx = offset + m * N + n;
|
||||
outputQ[outputIdx] = valueQ;
|
||||
outputK[outputIdx] = valueK;
|
||||
outputV[outputIdx] = valueV;
|
||||
let offset = workgroup_id.z * uniforms.M * uniforms.N;
|
||||
if (m < uniforms.M && n < uniforms.N) {
|
||||
let outputIdx = offset + m * uniforms.N + n;
|
||||
output_q[outputIdx] = valueQ;
|
||||
output_k[outputIdx] = valueK;
|
||||
output_v[outputIdx] = valueV;
|
||||
}
|
||||
}`;
|
||||
|
||||
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
|
||||
};
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'AttentionPrepare',
|
||||
shaderCache: {hint: JSON.stringify(parameters)},
|
||||
shaderCache: {inputDependencies: ['type', 'type', 'type']},
|
||||
getRunData: () => ({
|
||||
outputs: [
|
||||
{dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
|
|
@ -619,6 +616,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
|||
{dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -780,8 +780,10 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
|
||||
const is1DimensionDispatch = this.normalizedDispatchGroup[1] === 1 && this.normalizedDispatchGroup[2] === 1;
|
||||
const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3<u32>,
|
||||
@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>` :
|
||||
`@builtin(local_invocation_index) local_idx : u32,
|
||||
`@builtin(local_invocation_id) local_id : vec3<u32>,
|
||||
@builtin(local_invocation_index) local_idx : u32,
|
||||
@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups : vec3<u32>`;
|
||||
const globalIdxDefinition = is1DimensionDispatch ?
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType} from '../types';
|
||||
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';
|
||||
|
||||
import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
import {inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
import {createTransposeProgramInfo, TransposeAttributes} from './transpose';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
|
|
@ -228,7 +228,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
|
|||
};
|
||||
};
|
||||
|
||||
|
||||
export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({...attributes});
|
||||
|
||||
|
|
@ -239,30 +238,35 @@ const addBiasTranspose =
|
|||
hiddenSize: number, biasOffset: number) => {
|
||||
const outputShape = [batchSize, sequenceLength, hiddenSize];
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}];
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(qkv.dataType);
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const biasOffset = ${biasOffset}u;
|
||||
const hiddenSize = ${hiddenSize}u;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> qkv: array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read> bias: array<${dataType}>;
|
||||
@group(0) @binding(2) var<storage, read_write> qkv_with_bias: array<${dataType}>;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape);
|
||||
const qkvInput = inputVariable('qkv', qkv.dataType, outputShape);
|
||||
const biasInput = inputVariable('bias', bias.dataType, outputShape);
|
||||
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'output_size', type: 'u32'}, {name: 'bias_offset', type: 'u32'}, {name: 'hidden_size', type: 'u32'}
|
||||
];
|
||||
return `
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(qkvInput, biasInput, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset;
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
let bias_offset_idx = (global_idx % uniforms.hidden_size) + uniforms.bias_offset;
|
||||
|
||||
qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx];
|
||||
qkv_with_bias[global_idx] = qkv[global_idx] + bias[bias_offset_idx];
|
||||
}`;
|
||||
};
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'MultiHeadAttentionAddBias',
|
||||
shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})},
|
||||
shaderCache: {inputDependencies: ['type', 'type']},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue