mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] Support GroupQueryAttention (#20237)
TODOs: 1. Handle H * params.kvNumHeads greater than work group size limit. 2. Support BNSH kv cache.
This commit is contained in:
parent
90d49ccb9a
commit
8c59cd4fce
11 changed files with 1059 additions and 15 deletions
|
|
@ -54,6 +54,7 @@ Do not modify directly.*
|
|||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| GreaterOrEqual | ai.onnx(12-15,16+) | |
|
||||
| GroupQueryAttention | com.microsoft(1+) | |
|
||||
| HardSigmoid | ai.onnx(6+) | |
|
||||
| If | ai.onnx(1-10,11-12,13-18,19+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import {fastGelu} from './ops/fast-gelu';
|
|||
import {gather, parseGatherAttributes} from './ops/gather';
|
||||
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
import {groupQueryAttention, parseGroupQueryAttentionAttributes} from './ops/group-query-attention';
|
||||
import {instanceNorm} from './ops/instance-norm';
|
||||
import {layerNorm} from './ops/layer-norm';
|
||||
import {matMul} from './ops/matmul';
|
||||
|
|
@ -88,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
|
||||
['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
|
||||
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
|
||||
['InstanceNormalization', [instanceNorm]],
|
||||
['LayerNormalization', [layerNorm]],
|
||||
|
|
|
|||
|
|
@ -46,20 +46,24 @@ export interface AttentionParameters {
|
|||
headSize: number;
|
||||
vHeadSize: number;
|
||||
numHeads: number;
|
||||
isUnidirectional: boolean;
|
||||
kvNumHeads?: number;
|
||||
nReps?: number;
|
||||
isUnidirectional?: boolean;
|
||||
pastPresentShareBuffer: boolean;
|
||||
maskFilterValue: number;
|
||||
maskFilterValue?: number;
|
||||
maskType: AttentionMaskType;
|
||||
scale: number;
|
||||
broadcastResPosBias: boolean;
|
||||
passPastInKv: boolean;
|
||||
qkvFormat: AttentionQkvFormat;
|
||||
isPastkvBSNH?: boolean;
|
||||
}
|
||||
|
||||
export interface AttentionAttrs {
|
||||
numHeads: number;
|
||||
isUnidirectional: number;
|
||||
maskFilterValue: number;
|
||||
kvNumHeads?: number;
|
||||
isUnidirectional?: number;
|
||||
maskFilterValue?: number;
|
||||
scale: number;
|
||||
doRotary: number;
|
||||
qkvHiddenSizes: number[];
|
||||
|
|
@ -443,17 +447,20 @@ const createVxAttentionScoreProgramInfo =
|
|||
(_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
|
||||
pastSequenceLength: number) => {
|
||||
const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
|
||||
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
|
||||
const nReps = params.nReps ? params.nReps : 1;
|
||||
const repeatedVHiddenSize = params.vHiddenSize * nReps;
|
||||
const outputShape = [params.batchSize, params.sequenceLength, repeatedVHiddenSize];
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(params.vHeadSize / TILE_SIZE),
|
||||
y: Math.ceil(params.sequenceLength / TILE_SIZE),
|
||||
z: params.batchSize * params.numHeads
|
||||
};
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
|
||||
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
|
||||
{type: DataType.uint32, data: params.vHiddenSize}
|
||||
{type: DataType.uint32, data: repeatedVHiddenSize}
|
||||
];
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
|
|
@ -524,20 +531,22 @@ export const applyAttention =
|
|||
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const outputPresentKey = context.outputCount > 1;
|
||||
const outputPresentValue = context.outputCount > 2;
|
||||
const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
|
||||
const pastSequenceLength =
|
||||
parameters.kvNumHeads != null || (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
|
||||
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
|
||||
// Concatinate pastKey and K to produce presentKey.
|
||||
const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
|
||||
const key = outputPresentKey ? context.compute(
|
||||
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
|
||||
{inputs: concatKeyInputs, outputs: [1]})[0] :
|
||||
k;
|
||||
const key = parameters.kvNumHeads == null && outputPresentKey ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
|
||||
{inputs: concatKeyInputs, outputs: [1]})[0] :
|
||||
k;
|
||||
|
||||
// Concatinate pastValue and V to produce presentValue.
|
||||
const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
|
||||
const concatValueInputs = pastValue ? [pastValue, v] : [v];
|
||||
const value = outputPresentValue ?
|
||||
const value = parameters.kvNumHeads == null && outputPresentValue ?
|
||||
context.compute(
|
||||
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
|
||||
{inputs: concatValueInputs, outputs: [2]})[0] :
|
||||
|
|
|
|||
346
js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Normal file
346
js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
||||
import {maybeTransposeToBNSHAndAddBias} from './multihead-attentiion';
|
||||
import {createTileProgramInfo} from './tile';
|
||||
import {createTransposeProgramInfo, TransposeAttributes} from './transpose';
|
||||
|
||||
export const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
const query = inputs[0];
|
||||
const key = inputs[1];
|
||||
const value = inputs[2];
|
||||
const pastKey = inputs[3];
|
||||
const pastValue = inputs[4];
|
||||
|
||||
// Abbreviation and Meanings:
|
||||
// B: batch_size
|
||||
// S: sequence_length (input sequence length of query)
|
||||
// P: past_sequence_length (past sequence length of key or value)
|
||||
// L: kv_sequence_length (input sequence length of key or value)
|
||||
// M: max_sequence_length
|
||||
// T: total_sequence_length = past_sequence_length + kv_sequence_length
|
||||
// N: num_heads
|
||||
// H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
|
||||
// H_v: v_head_size
|
||||
// D_i: input hidden size
|
||||
// D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
|
||||
// D_v: v_hidden_size = num_heads * v_head_size
|
||||
|
||||
// past_key : (B, N, S*, H)
|
||||
// past_value : (B, N, S*, H)
|
||||
// When no packing for q/k/v:
|
||||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, L, D) or (B, N, S*, H)
|
||||
// value (V) : (B, L, D_v) or (B, N, S*, H)
|
||||
// When packed kv is used:
|
||||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, L, N, 2, H)
|
||||
// value (V) : None
|
||||
// When packed qkv is used:
|
||||
// query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
|
||||
// key (K) : None
|
||||
// value (V) : None
|
||||
|
||||
if (query.dims.length !== 3 && query.dims.length !== 5) {
|
||||
throw new Error('Input query is expected to have 3 or 5 dimensions');
|
||||
}
|
||||
|
||||
const dmmhaPacking = false;
|
||||
const batchSize = query.dims[0];
|
||||
const sequenceLength = query.dims[1];
|
||||
const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) :
|
||||
attributes.numHeads * query.dims[4];
|
||||
let kvSequenceLength = sequenceLength;
|
||||
|
||||
let pastSequenceLength = 0;
|
||||
let maxSequenceLength = 0;
|
||||
const headSize = Math.floor(hiddenSize / attributes.numHeads);
|
||||
const hasPastKey = pastKey && pastKey.dims.length !== 0;
|
||||
const hasPastValue = pastValue && pastValue.dims.length !== 0;
|
||||
// TODO : this should be from attributes.
|
||||
const isPastkvBSNH = true;
|
||||
if (hasPastKey && hasPastValue) {
|
||||
if (pastKey.dims.length !== 4) {
|
||||
throw new Error('Input "past_key" is expected to have 4 dimensions');
|
||||
}
|
||||
if (pastValue.dims.length !== 4) {
|
||||
throw new Error('Input "past_value" is expected to have 4 dimensions');
|
||||
}
|
||||
if (isPastkvBSNH) {
|
||||
// For BSNH
|
||||
pastSequenceLength = pastKey.dims[1];
|
||||
maxSequenceLength = pastKey.dims[1];
|
||||
} else {
|
||||
// For BNSH
|
||||
pastSequenceLength = pastKey.dims[2];
|
||||
maxSequenceLength = pastKey.dims[2];
|
||||
}
|
||||
} else if (hasPastKey || hasPastValue) {
|
||||
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
|
||||
}
|
||||
|
||||
let qkvFormat: AttentionQkvFormat;
|
||||
if (key) {
|
||||
if (query.dims.length !== 3) {
|
||||
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
|
||||
}
|
||||
if (key.dims.length < 3 || key.dims.length > 5) {
|
||||
throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
|
||||
}
|
||||
if (query.dims[0] !== key.dims[0]) {
|
||||
throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
|
||||
}
|
||||
|
||||
if (key.dims.length === 3) {
|
||||
if (query.dims[2] % key.dims[2] !== 0) {
|
||||
throw new Error('Dimension 2 of "query" should be a multiple of "key"');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qkvBSNH;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else if (key.dims.length === 5) {
|
||||
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
|
||||
throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
|
||||
}
|
||||
if (value) {
|
||||
throw new Error('Expect "value" be none when "key" has packed kv format.');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else { // key_dims.size() == 4 (cross-attention with past_key)
|
||||
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
|
||||
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
|
||||
}
|
||||
|
||||
qkvFormat = AttentionQkvFormat.unknown;
|
||||
kvSequenceLength = key.dims[2];
|
||||
}
|
||||
} else { // packed QKV
|
||||
if (query.dims.length !== 3 && query.dims.length !== 5) {
|
||||
throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
|
||||
}
|
||||
if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
|
||||
throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
|
||||
}
|
||||
|
||||
qkvFormat = AttentionQkvFormat.qkvBSN3H;
|
||||
}
|
||||
|
||||
const maskType: AttentionMaskType = AttentionMaskType.none;
|
||||
let passPastInKv = false;
|
||||
let vHiddenSize = hiddenSize;
|
||||
if (value) {
|
||||
if (value.dims.length !== 3 && value.dims.length !== 4) {
|
||||
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
|
||||
}
|
||||
|
||||
if (query.dims[0] !== value.dims[0]) {
|
||||
throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
|
||||
}
|
||||
|
||||
if (value.dims.length === 3) {
|
||||
if (kvSequenceLength !== value.dims[1]) {
|
||||
throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
|
||||
}
|
||||
vHiddenSize = value.dims[2];
|
||||
} else {
|
||||
if (kvSequenceLength !== value.dims[2]) {
|
||||
throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
|
||||
}
|
||||
vHiddenSize = value.dims[1] * value.dims[3];
|
||||
passPastInKv = true;
|
||||
}
|
||||
}
|
||||
const totalSequenceLength = pastSequenceLength + kvSequenceLength;
|
||||
const broadcastResPosBias = false;
|
||||
|
||||
return {
|
||||
batchSize,
|
||||
sequenceLength,
|
||||
pastSequenceLength,
|
||||
kvSequenceLength,
|
||||
totalSequenceLength,
|
||||
maxSequenceLength,
|
||||
inputHiddenSize: 0,
|
||||
hiddenSize,
|
||||
vHiddenSize,
|
||||
headSize,
|
||||
vHeadSize: Math.floor(vHiddenSize / attributes.kvNumHeads!),
|
||||
numHeads: attributes.numHeads,
|
||||
kvNumHeads: attributes.kvNumHeads,
|
||||
nReps: attributes.numHeads / attributes.kvNumHeads!,
|
||||
pastPresentShareBuffer: false,
|
||||
maskType,
|
||||
scale: attributes.scale,
|
||||
broadcastResPosBias,
|
||||
passPastInKv,
|
||||
qkvFormat,
|
||||
isPastkvBSNH,
|
||||
};
|
||||
};
|
||||
|
||||
const createConcatProgramInfo =
|
||||
(a: TensorView, b: TensorView|undefined, dataType: DataType, params: AttentionParameters): ProgramInfo => {
|
||||
const outputShape = [params.batchSize, params.totalSequenceLength, params.kvNumHeads!, params.headSize];
|
||||
const component = 4;
|
||||
const outputSize = ShapeUtil.size(outputShape) / component;
|
||||
const presentSequenceLength = params.totalSequenceLength;
|
||||
const output = outputVariable('present_kv', dataType, outputShape.length, component);
|
||||
const inputA = inputVariable('new_kv', a.dataType, a.dims.length, component);
|
||||
const inputB = b ? inputVariable('past_kv', b.dataType, b.dims.length, component) : undefined;
|
||||
|
||||
const H = Math.ceil(params.headSize / component);
|
||||
const dispatch = {x: presentSequenceLength, y: a.dims[0], z: 1};
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = b ? ['rank', 'rank'] : ['rank'];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: params.pastSequenceLength},
|
||||
{type: DataType.uint32, data: params.kvSequenceLength},
|
||||
{type: DataType.uint32, data: params.totalSequenceLength}
|
||||
];
|
||||
|
||||
const inputs = [inputA];
|
||||
if (inputB) {
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b!.dims),
|
||||
...createTensorShapeVariables(outputShape));
|
||||
inputs.push(inputB);
|
||||
} else {
|
||||
programUniforms.push(...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
const uniforms: UniformsArrayType = [
|
||||
{name: 'output_size', type: 'u32'}, {name: 'past_seqlen', type: 'u32'}, {name: 'new_seqlen', type: 'u32'},
|
||||
{name: 'present_seqlen', type: 'u32'}
|
||||
];
|
||||
|
||||
const pastStr = ` let past_batch_stride = uniforms.past_seqlen * num_heads * H;
|
||||
var past_head_stride = uniforms.past_seqlen * H;
|
||||
if (is_bsnh) {
|
||||
past_head_stride = H;
|
||||
}
|
||||
let in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h;
|
||||
present_kv[out_offset] = past_kv[in_offset];`;
|
||||
const newStr = ` let new_batch_stride = uniforms.new_seqlen * num_heads * H;
|
||||
let new_row_stride = num_heads * H;
|
||||
let new_head_stride = H;
|
||||
let in_offset = b * new_batch_stride + (s - past_seqlen) * new_row_stride + n * new_head_stride + h;
|
||||
present_kv[out_offset] = new_kv[in_offset];`;
|
||||
const concatStr = b ? `if (s < past_seqlen) {
|
||||
${pastStr}
|
||||
} else if (s < past_seqlen + uniforms.new_seqlen) {
|
||||
${newStr}
|
||||
}` :
|
||||
`if (s < past_seqlen + uniforms.new_seqlen) {
|
||||
${newStr}
|
||||
}`;
|
||||
|
||||
// TODO: handle H * params.kvNumHeads greater than maxComputeInvocationsPerWorkgroup limit.
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputs, output)}
|
||||
${shaderHelper.mainStart([
|
||||
H, params.kvNumHeads!, 1
|
||||
])}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
var indices = ${output.offsetToIndices('global_idx')};
|
||||
let h = local_id.x;
|
||||
let n = local_id.y;
|
||||
let s = workgroup_id.x;
|
||||
let b = workgroup_id.y;
|
||||
let num_heads = ${params.kvNumHeads!}u;
|
||||
let H = ${H}u;
|
||||
|
||||
let present_seqlen = uniforms.present_seqlen;
|
||||
let present_batch_stride = present_seqlen * num_heads * H;
|
||||
var row_stride = H;
|
||||
let is_bsnh = ${params.isPastkvBSNH};
|
||||
|
||||
if (is_bsnh) {
|
||||
row_stride = num_heads * H;
|
||||
}
|
||||
var present_head_stride = present_seqlen * H;
|
||||
if (is_bsnh) {
|
||||
present_head_stride = H;
|
||||
}
|
||||
|
||||
let past_seqlen = uniforms.past_seqlen;
|
||||
|
||||
let out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h;
|
||||
${concatStr}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'ConcatPastNew',
|
||||
shaderCache: {hint: `${params.kvNumHeads!}${H}${!!b}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType}],
|
||||
dispatchGroup: dispatch,
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
export const parseGroupQueryAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({...attributes});
|
||||
|
||||
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});
|
||||
|
||||
const maybeExpandAndTransposeToBNSH =
|
||||
(context: ComputeContext, input: TensorView, pastKV: TensorView|undefined, params: AttentionParameters,
|
||||
outputIndex: number) => {
|
||||
let reshapedInput = input;
|
||||
const numHeads = params.kvNumHeads!;
|
||||
const nReps = params.nReps!;
|
||||
if (input.dims.length === 3 && params.kvSequenceLength !== 0) {
|
||||
reshapedInput = input.reshape([params.batchSize, params.kvSequenceLength, numHeads, params.headSize]);
|
||||
}
|
||||
|
||||
if (pastKV) {
|
||||
reshapedInput = context.compute(
|
||||
createConcatProgramInfo(reshapedInput, pastKV, reshapedInput.dataType, params),
|
||||
{inputs: [reshapedInput, pastKV], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0];
|
||||
} else {
|
||||
reshapedInput = context.compute(
|
||||
createConcatProgramInfo(reshapedInput, undefined, reshapedInput.dataType, params),
|
||||
{inputs: [reshapedInput], outputs: [params.isPastkvBSNH ? outputIndex : -1]})[0];
|
||||
}
|
||||
if (nReps !== 1) {
|
||||
reshapedInput = context.compute(
|
||||
createTileProgramInfo([reshapedInput], [1, 1, 1, nReps]), {inputs: [reshapedInput], outputs: [-1]})[0];
|
||||
reshapedInput =
|
||||
reshapedInput.reshape([params.batchSize, params.totalSequenceLength, numHeads * nReps, params.headSize]);
|
||||
}
|
||||
|
||||
return context.compute(
|
||||
createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
|
||||
{inputs: [reshapedInput], outputs: [-1]})[0];
|
||||
};
|
||||
|
||||
export const groupQueryAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
|
||||
const params = validateInputs(context.inputs, attributes);
|
||||
if (context.inputs[0].dims.length === 5) {
|
||||
throw new Error('Packed QKV is not implemented');
|
||||
}
|
||||
|
||||
if (context.inputs[1]?.dims.length === 5) {
|
||||
throw new Error('Packed KV is not implemented');
|
||||
}
|
||||
|
||||
const Q = maybeTransposeToBNSHAndAddBias(
|
||||
context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], undefined,
|
||||
0);
|
||||
const pastKey = context.inputs[3] && context.inputs[3].dims.length !== 0 ? context.inputs[3] : undefined;
|
||||
const pastValue = context.inputs[4] && context.inputs[4].dims.length !== 0 ? context.inputs[4] : undefined;
|
||||
const K = maybeExpandAndTransposeToBNSH(context, context.inputs[1], pastKey, params, 1);
|
||||
const V = maybeExpandAndTransposeToBNSH(context, context.inputs[2], pastValue, params, 2);
|
||||
applyAttention(context, Q, K, V, undefined, undefined, undefined, undefined, undefined, params, attributes);
|
||||
};
|
||||
|
|
@ -286,7 +286,7 @@ const addBiasTranspose =
|
|||
{inputs: [qkv, bias], outputs: [-1]})[0];
|
||||
};
|
||||
|
||||
const maybeTransposeToBNSHAndAddBias =
|
||||
export const maybeTransposeToBNSHAndAddBias =
|
||||
(context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number,
|
||||
input: TensorView, bias?: TensorView, biasOffset?: number) => {
|
||||
// const newDims = [];
|
||||
|
|
|
|||
|
|
@ -47,9 +47,9 @@ const getOutputShape = (inputShape: readonly number[], repeats: readonly number[
|
|||
return outputShape;
|
||||
};
|
||||
|
||||
export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
|
||||
export const createTileProgramInfo = (inputs: readonly TensorView[], shape?: number[]): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
const repeats: readonly number[] = getRepeats(inputs[1]);
|
||||
const repeats: readonly number[] = shape == null ? getRepeats(inputs[1]) : shape;
|
||||
const outputShape = getOutputShape(inputShape, repeats);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
|
|
|
|||
616
js/web/test/data/ops/group-query-attention.jsonc
Normal file
616
js/web/test/data/ops/group-query-attention.jsonc
Normal file
|
|
@ -0,0 +1,616 @@
|
|||
[
|
||||
{
|
||||
"name": "GroupQueryAttention Basic",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
|
||||
8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// key, BS*
|
||||
{
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
|
||||
"dims": [1, 3, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// value, BS*
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
|
||||
"dims": [1, 3, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2,
|
||||
131, 22, 21, 131, 22, 21, 2, 1, 1, 1, 1, 2, 131, 22, 21, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
// present key, BS*
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
|
||||
"dims": [1, 3, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
// present value, BS*
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21],
|
||||
"dims": [1, 3, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GroupQueryAttention Scale",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" },
|
||||
{ "name": "scale", "data": 2.0, "type": "float" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
|
||||
],
|
||||
"dims": [1, 4, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1.000006079673767, 1.000006079673767, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1,
|
||||
1, 1, 1, 1.9820137023925781, 1.9820137023925781, 1.9999991655349731, 1.9999991655349731
|
||||
],
|
||||
"dims": [1, 4, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
// present key, BS*
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
// present value, BS*
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
{
|
||||
"name": "GroupQueryAttention, different sequence length",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
|
||||
],
|
||||
"dims": [1, 4, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1.014165997505188, 1.014165997505188, 1.0000015497207642, 1.0000015497207642, 1.99828040599823,
|
||||
1.99828040599823, 1.9998981952667236, 1.9998981952667236, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
|
||||
1.9995813369750977, 1.9995813369750977, 1.9999752044677734, 1.9999752044677734, 1, 1, 1, 1,
|
||||
1.8044296503067017, 1.8044296503067017, 1.9929646253585815, 1.9929646253585815
|
||||
],
|
||||
"dims": [1, 4, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 9, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 2, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GroupQueryAttention Basic, q k v same head number",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 4, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4,
|
||||
8, 12, 233, 4, 5, 6, 7, 8, 5, 6, 7, 8, 1, 1, 3, 4
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
|
||||
2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
|
||||
12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 12, 21, 131, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2,
|
||||
131, 22, 21, 131, 22, 21, 2, 2, 131, 22, 21, 1, 1, 1, 1, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 9, 1, 1, 2, 2, 2, 2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2,
|
||||
2, 1, 12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 4, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21, 1, 9, 1, 1, 2, 2, 2, 2, 1,
|
||||
12, 21, 131, 22, 21, 2, 2, 131, 22, 21, 2, 2, 131, 22, 21
|
||||
],
|
||||
"dims": [1, 3, 4, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GroupQueryAttention, no past kv, used as reference",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
|
||||
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
||||
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
|
||||
107, 108, 109, 110, 111, 112
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 7, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
||||
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BS*
|
||||
{
|
||||
"data": null,
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
|
||||
54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
|
||||
54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
|
||||
52, 53, 54, 55, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
||||
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 1",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
|
||||
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
||||
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
|
||||
107, 108, 109, 110, 111, 112
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// new key, BS*
|
||||
{
|
||||
"data": [
|
||||
9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
|
||||
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 6, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// new value, BS*
|
||||
{
|
||||
"data": [
|
||||
8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
|
||||
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 6, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BSNH layout: [batch, past_sequence_length, kv_num_heads, head_size]
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BSNH layout: [batch, past_sequence_length, kv_num_heads, head_size]
|
||||
{
|
||||
"data": [0, 1, 2, 3, 4, 5, 6, 7],
|
||||
"dims": [1, 1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
|
||||
54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
|
||||
54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
|
||||
52, 53, 54, 55, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
||||
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "GroupQueryAttention Past&Present KV BSNH, key seqlen = 2",
|
||||
"operator": "GroupQueryAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [
|
||||
{ "name": "num_heads", "data": 4, "type": "int" },
|
||||
{ "name": "kv_num_heads", "data": 2, "type": "int" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
|
||||
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
|
||||
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106,
|
||||
107, 108, 109, 110, 111, 112
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
// new key, BS*
|
||||
{
|
||||
"data": [
|
||||
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
|
||||
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 5, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// new value, BS*
|
||||
{
|
||||
"data": [
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
|
||||
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 5, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
// past key, BSNH layout: [batch, past_sequence_length, kv_num_heads, head_size]
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
"dims": [1, 2, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// past value, BSNH layout: [batch, past_sequence_length, kv_num_heads, head_size]
|
||||
{
|
||||
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
"dims": [1, 2, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
// seqlens_k, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
},
|
||||
// total_sequence_length, unimplemented
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1],
|
||||
"type": "int32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53,
|
||||
54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51,
|
||||
48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53,
|
||||
54, 55, 48, 49, 50, 51, 48, 49, 50, 51, 52, 53, 54, 55, 52, 53, 54, 55, 48, 49, 50, 51, 48, 49, 50, 51,
|
||||
52, 53, 54, 55, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 16],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
|
||||
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
|
||||
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55
|
||||
],
|
||||
"dims": [1, 7, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1361,6 +1361,7 @@
|
|||
"gemm.jsonc",
|
||||
"global-average-pool.jsonc",
|
||||
"greater.jsonc",
|
||||
"group-query-attention.jsonc",
|
||||
"instance-norm.jsonc",
|
||||
"less.jsonc",
|
||||
"log.jsonc",
|
||||
|
|
|
|||
24
onnxruntime/contrib_ops/js/bert/group_query_attention.cc
Normal file
24
onnxruntime/contrib_ops/js/bert/group_query_attention.cc
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "group_query_attention.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsepSupportedFloatTypes;
|
||||
|
||||
// Registers the com.microsoft GroupQueryAttention kernel (opset 1) with the
// JSEP (WebGPU/JS) execution provider. The computation itself is implemented
// in JavaScript; this C++ side only declares the kernel and forwards its
// attributes (see group_query_attention.h).
ONNX_OPERATOR_KERNEL_EX(
    GroupQueryAttention,
    kMSDomain,
    1,
    kJsExecutionProvider,
    (*KernelDefBuilder::Create())
        // "T" (query/key/value and outputs) is constrained to the float types
        // the JSEP backend supports.
        .TypeConstraint("T", JsepSupportedFloatTypes()),
    GroupQueryAttention);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
43
onnxruntime/contrib_ops/js/bert/group_query_attention.h
Normal file
43
onnxruntime/contrib_ops/js/bert/group_query_attention.h
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsKernel;
|
||||
|
||||
// JSEP (WebGPU/JS) kernel wrapper for the com.microsoft GroupQueryAttention
// operator. The constructor validates the attention attributes and forwards
// them to the JavaScript implementation via JSEP_INIT_KERNEL_ATTRIBUTE; the
// attention computation itself runs on the JS side.
class GroupQueryAttention : public JsKernel {
 public:
  explicit GroupQueryAttention(const OpKernelInfo& info)
      : JsKernel(info) {
    int64_t num_heads = 0;
    int64_t kv_num_heads = 0;
    // num_heads must be a positive multiple of kv_num_heads: each key/value
    // head is shared by (num_heads / kv_num_heads) query heads.
    ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
    ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0);
    num_heads_ = static_cast<int>(num_heads);
    kv_num_heads_ = static_cast<int>(kv_num_heads);
    // 0.0f means "not specified"; the implementation then uses the default
    // scale of 1/sqrt(head_size) (see member comment below).
    scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
    // Pass the attributes to the JS kernel as a {numHeads, kvNumHeads, scale}
    // object; $1..$3 are substituted by the trailing arguments.
    JSEP_INIT_KERNEL_ATTRIBUTE(GroupQueryAttention, ({
                                 "numHeads" : $1,
                                 "kvNumHeads" : $2,
                                 "scale" : $3,
                               }),
                               static_cast<int32_t>(num_heads_),
                               static_cast<int32_t>(kv_num_heads_),
                               static_cast<float>(scale_));
  }

 protected:
  int num_heads_;     // number of attention heads
  int kv_num_heads_;  // number of k and v heads
  float scale_;       // custom scale will be used if specified. Default value is 1/sqrt(head_size)
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -13,6 +13,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSp
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding);
|
||||
|
|
@ -34,6 +35,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue