mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-09 00:30:53 +00:00
[js/web] JSEP Attention & MultiHeadAttention (#17742)
### Description This is a narrow implementation of Attention/MultiHeadAttention as it does not support: a. inputs 5-7 for MHA b. packed QKV/KV c. past/present d. attention mask But it works well for StableDiffusion and can be extended later. It reduces VRAM usage as it combines many ops into few I've updated demo here https://islamov.ai/stable-diffusion-webgpu/ it takes ~13sec for 1 image with 20 steps on RTX3090Ti and about 25s on M1 Pro VRAM usage is about 8gb if you don't use img2img Going to focus on SDXL now --------- Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com> Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
parent
a5537f2f56
commit
fac3e33da5
13 changed files with 1866 additions and 0 deletions
|
|
@ -20,6 +20,7 @@ Do not modify directly.*
|
|||
| Asinh | ai.onnx(9+) | |
|
||||
| Atan | ai.onnx(7+) | |
|
||||
| Atanh | ai.onnx(9+) | |
|
||||
| Attention | com.microsoft(1+) | need implementing mask and past/present |
|
||||
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
|
||||
| BiasAdd | com.microsoft(1+) | |
|
||||
| BiasSplitGelu | com.microsoft(1+) | |
|
||||
|
|
@ -61,6 +62,7 @@ Do not modify directly.*
|
|||
| MemcpyFromHost | ai.onnx(1+) | |
|
||||
| MemcpyToHost | ai.onnx(1+) | |
|
||||
| Mul | ai.onnx(7-12,13,14+) | |
|
||||
| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present |
|
||||
| Neg | ai.onnx(6-12,13+) | |
|
||||
| Not | ai.onnx(1+) | |
|
||||
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
|
||||
import {attention, parseAttentionAttributes} from './ops/attention';
|
||||
import {biasAdd} from './ops/bias-add';
|
||||
import {biasSplitGelu} from './ops/bias-split-gelu';
|
||||
import * as binaryOps from './ops/binary-op';
|
||||
|
|
@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
|
|||
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
|
||||
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
|
||||
import {matMul} from './ops/matmul';
|
||||
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
|
||||
import {pad, parsePadAttributes} from './ops/pad';
|
||||
import * as pool from './ops/pool';
|
||||
import {range} from './ops/range';
|
||||
|
|
@ -46,6 +48,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Asinh', [unaryOps.asinh]],
|
||||
['Atan', [unaryOps.atan]],
|
||||
['Atanh', [unaryOps.atanh]],
|
||||
['Attention', [attention, parseAttentionAttributes]],
|
||||
// TODO: support new attributes for AveragePool-10
|
||||
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
|
||||
['BiasAdd', [biasAdd]],
|
||||
|
|
@ -86,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
// TODO: support new attributes for MaxPool-8 and MaxPool-10
|
||||
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
|
||||
['Mul', [binaryOps.mul]],
|
||||
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
|
||||
['Neg', [unaryOps.neg]],
|
||||
['Not', [unaryOps.not]],
|
||||
['Pad', [pad, parsePadAttributes]],
|
||||
|
|
|
|||
635
js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Normal file
635
js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Normal file
|
|
@ -0,0 +1,635 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType} from '../types';
|
||||
|
||||
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
export const enum AttentionQkvFormat {
|
||||
unknown, // enum value not set, or depends on qkv projection implementation details
|
||||
qkvBNSH, // for non-packed qkv, permuted
|
||||
qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention
|
||||
qkvBSN3H, // for TRT fused attention, qkv are packed
|
||||
qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
|
||||
qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed
|
||||
qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed.
|
||||
qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed
|
||||
}
|
||||
|
||||
export const enum AttentionMaskType {
|
||||
none, // No mask
|
||||
mask1dKeySeqLen, // [batch_size], key sequence length
|
||||
mask1dEndStart, // [2 * batch_size] with end positions and start positions
|
||||
mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
|
||||
// ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
|
||||
// key_start[batch_size - 1], key_end[batch_size - 1]]
|
||||
mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
|
||||
mask2dKeyPadding, // [batch_size, total_sequence_length]
|
||||
mask3dAttention, // [batch_size, sequence_length, total_sequence_length]
|
||||
mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
|
||||
maskUnknown
|
||||
}
|
||||
|
||||
export interface AttentionParameters {
|
||||
batchSize: number;
|
||||
sequenceLength: number;
|
||||
pastSequenceLength: number;
|
||||
kvSequenceLength: number;
|
||||
totalSequenceLength: number;
|
||||
maxSequenceLength: number;
|
||||
inputHiddenSize: number;
|
||||
hiddenSize: number;
|
||||
vHiddenSize: number;
|
||||
headSize: number;
|
||||
vHeadSize: number;
|
||||
numHeads: number;
|
||||
isUnidirectional: boolean;
|
||||
pastPresentShareBuffer: boolean;
|
||||
maskFilterValue: number;
|
||||
maskType: AttentionMaskType;
|
||||
scale: number;
|
||||
broadcastResPosBias: boolean;
|
||||
passPastInKv: boolean;
|
||||
qkvFormat: AttentionQkvFormat;
|
||||
}
|
||||
|
||||
export interface AttentionAttrs {
|
||||
numHeads: number;
|
||||
isUnidirectional: number;
|
||||
maskFilterValue: number;
|
||||
scale: number;
|
||||
doRotary: number;
|
||||
qkvHiddenSizes: number[];
|
||||
pastPresentShareBuffer: boolean;
|
||||
}
|
||||
|
||||
const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
// Abbreviation and Meanings:
|
||||
// B: batch_size
|
||||
// S: sequence_length (input sequence length of query)
|
||||
// P: past_sequence_length (past sequence length of key or value)
|
||||
// L: kv_sequence_length (input sequence length of key or value)
|
||||
// M: max_sequence_length
|
||||
// T: total_sequence_length = past_sequence_length + kv_sequence_length
|
||||
// N: num_heads
|
||||
// H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
|
||||
// H_v: v_head_size
|
||||
// D_i: input hidden size
|
||||
// D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
|
||||
// D_v: v_hidden_size = num_heads * v_head_size
|
||||
|
||||
// When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value).
|
||||
|
||||
// Input shapes:
|
||||
// input (Q/K/V) : (B, S, D_i)
|
||||
// weights (Q/K/V) : (D_i, D + D + D_v)
|
||||
// bias (Q/K/V) : (D + D + D_v)
|
||||
// mask_index : see below
|
||||
// past (K/V) : (2, B, N, P, H) or NULL
|
||||
// relative_position_bias : (B, N, S, T) or NULL
|
||||
|
||||
// For mask_index, the following shapes are supported:
|
||||
// NULL, (B, 1), (1, 1)
|
||||
// (B), (2 * B), (3 * B + 2)
|
||||
// (B, T)
|
||||
// (B, S, T)
|
||||
// (B, 1, M, M)
|
||||
//
|
||||
// When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
|
||||
// than hidden dimension of Q, K and V.
|
||||
|
||||
const input = inputs[0];
|
||||
const weights = inputs[1];
|
||||
const bias = inputs[2];
|
||||
const maskIndex = inputs[3];
|
||||
const past = inputs[4];
|
||||
const relativePositionBias = inputs[5];
|
||||
|
||||
if (past && relativePositionBias) {
|
||||
throw new Error('Attention cannot have both past and relative_position_bias');
|
||||
}
|
||||
|
||||
if (input.dims.length !== 3) {
|
||||
throw new Error('Input "input" must have 3 dimensions');
|
||||
}
|
||||
|
||||
const batchSize = input.dims[0];
|
||||
const sequenceLength = input.dims[1];
|
||||
const inputHiddenSize = input.dims[2];
|
||||
|
||||
if (bias.dims.length !== 1) {
|
||||
throw new Error('Input "bias" is expected to have 1 dimensions');
|
||||
}
|
||||
|
||||
if (weights.dims.length !== 2) {
|
||||
throw new Error('Input "weights" is expected to have 2 dimensions');
|
||||
}
|
||||
|
||||
if (weights.dims[0] !== inputHiddenSize) {
|
||||
throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0');
|
||||
}
|
||||
|
||||
if (bias.dims[0] !== weights.dims[1]) {
|
||||
throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"');
|
||||
}
|
||||
|
||||
let qHiddenSize = bias.dims[0] / 3;
|
||||
let kHiddenSize = qHiddenSize;
|
||||
let vHiddenSize = kHiddenSize;
|
||||
if (attributes.qkvHiddenSizes.length > 0) {
|
||||
if (attributes.qkvHiddenSizes.length !== 3) {
|
||||
throw new Error('qkv_hidden_sizes attribute should have 3 elements');
|
||||
}
|
||||
for (const sz of attributes.qkvHiddenSizes) {
|
||||
if (sz % attributes.numHeads !== 0) {
|
||||
throw new Error('qkv_hidden_sizes should be divisible by num_heads');
|
||||
}
|
||||
}
|
||||
|
||||
qHiddenSize = attributes.qkvHiddenSizes[0];
|
||||
kHiddenSize = attributes.qkvHiddenSizes[1];
|
||||
vHiddenSize = attributes.qkvHiddenSizes[2];
|
||||
}
|
||||
|
||||
const kvSequenceLength = sequenceLength;
|
||||
|
||||
if (qHiddenSize !== kHiddenSize) {
|
||||
throw new Error('qkv_hidden_sizes first element should be same as the second');
|
||||
}
|
||||
|
||||
if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) {
|
||||
throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes');
|
||||
}
|
||||
|
||||
let pastSequenceLength = 0;
|
||||
if (past) {
|
||||
if (kHiddenSize !== vHiddenSize) {
|
||||
throw new Error('Input "past" expect k_hidden_size == v_hidden_size');
|
||||
}
|
||||
if (past.dims.length !== 5) {
|
||||
throw new Error('Input "past" must have 5 dimensions');
|
||||
}
|
||||
if (past.dims[0] !== 2) {
|
||||
throw new Error('Input "past" first dimension must be 2');
|
||||
}
|
||||
if (past.dims[1] !== batchSize) {
|
||||
throw new Error('Input "past" second dimension must be batch_size');
|
||||
}
|
||||
if (past.dims[2] !== attributes.numHeads) {
|
||||
throw new Error('Input "past" third dimension must be num_heads');
|
||||
}
|
||||
if (past.dims[4] !== kHiddenSize / attributes.numHeads) {
|
||||
throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads');
|
||||
}
|
||||
|
||||
if (!attributes.pastPresentShareBuffer) {
|
||||
pastSequenceLength = past.dims[3];
|
||||
}
|
||||
// TODO: handle past_seq_len
|
||||
}
|
||||
|
||||
const totalSequenceLength = kvSequenceLength + pastSequenceLength;
|
||||
const maxSequenceLength = -1;
|
||||
|
||||
const maskType = AttentionMaskType.none;
|
||||
if (maskIndex) {
|
||||
// maskType = AttentionMaskType.MASK_UNKNOWN;
|
||||
// TODO: handle mask
|
||||
throw new Error('Mask not supported');
|
||||
}
|
||||
|
||||
if (past) {
|
||||
throw new Error('past is not supported');
|
||||
}
|
||||
if (relativePositionBias) {
|
||||
throw new Error('relativePositionBias is not supported');
|
||||
}
|
||||
|
||||
return {
|
||||
batchSize,
|
||||
sequenceLength,
|
||||
pastSequenceLength,
|
||||
kvSequenceLength,
|
||||
totalSequenceLength,
|
||||
maxSequenceLength,
|
||||
inputHiddenSize,
|
||||
hiddenSize: qHiddenSize,
|
||||
vHiddenSize,
|
||||
headSize: Math.floor(qHiddenSize / attributes.numHeads),
|
||||
vHeadSize: Math.floor(vHiddenSize / attributes.numHeads),
|
||||
numHeads: attributes.numHeads,
|
||||
isUnidirectional: false,
|
||||
pastPresentShareBuffer: false,
|
||||
maskFilterValue: attributes.maskFilterValue,
|
||||
maskType,
|
||||
scale: attributes.scale,
|
||||
broadcastResPosBias: false,
|
||||
passPastInKv: false,
|
||||
qkvFormat: AttentionQkvFormat.qkvBNSH,
|
||||
};
|
||||
};
|
||||
|
||||
export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({...attributes});
|
||||
|
||||
export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => {
|
||||
const components = getMaxComponents(d);
|
||||
const inputHelper = outputVariable('x', input.dataType, input.dims, components);
|
||||
|
||||
let threadMaxValue = 'threadMaxVector';
|
||||
if (components === 2) {
|
||||
threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
|
||||
} else if (components === 4) {
|
||||
threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
|
||||
}
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType);
|
||||
let WG = 64;
|
||||
const dComp = d / components;
|
||||
if (dComp < WG) {
|
||||
WG = 1;
|
||||
} else if (dComp / 8 < 64) {
|
||||
WG = Math.ceil(dComp / 8);
|
||||
}
|
||||
const elementsPerWG = Math.ceil(d / components / WG);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const dInv: ${dataType} = 1 / ${d};
|
||||
const dComp = ${d / components};
|
||||
var<workgroup> wgMax: array<f32, ${WG}>;
|
||||
var<workgroup> wgSum: array<f32, ${WG}>;
|
||||
|
||||
${shaderHelper.declareVariables(inputHelper)}
|
||||
@compute @workgroup_size(${WG}, 1, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_index) local_index : u32) {
|
||||
let localOffset = local_index * ${elementsPerWG};
|
||||
let offset: u32 = workgroup_id.x * dComp + localOffset;
|
||||
|
||||
var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')};
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector);
|
||||
}
|
||||
wgMax[local_index] = ${threadMaxValue};
|
||||
workgroupBarrier();
|
||||
|
||||
var maxValue = -3.402823e+38f;
|
||||
for (var i = 0u; i < ${WG}; i++) {
|
||||
maxValue = max(wgMax[i], maxValue);
|
||||
}
|
||||
|
||||
var sumVector = ${fillVector('f32', components, '0')};
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue);
|
||||
}
|
||||
wgSum[local_index] = ${sumVector('sumVector', components)};
|
||||
workgroupBarrier();
|
||||
|
||||
var sum: f32 = 0;
|
||||
for (var i = 0u; i < ${WG}; i++) {
|
||||
sum += wgSum[i];
|
||||
}
|
||||
|
||||
if (sum == 0) {
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
x[offset + i] = ${fillVector(dataType, components, 'dInv')};
|
||||
}
|
||||
} else {
|
||||
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
|
||||
let f32input = ${castToF32(dataType, components, 'x[offset + i]')};
|
||||
x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum);
|
||||
}
|
||||
}
|
||||
}`;
|
||||
|
||||
context.compute(
|
||||
{
|
||||
name: 'AttentionProbsSoftmax',
|
||||
shaderCache: {hint: `${d}`},
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [],
|
||||
dispatchGroup: {x: n},
|
||||
}),
|
||||
},
|
||||
{inputs: [input], outputs: []});
|
||||
};
|
||||
|
||||
const computeAttentionProbs =
|
||||
(context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
|
||||
parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const probsShape = [
|
||||
parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
|
||||
parameters.kvSequenceLength + parameters.pastSequenceLength
|
||||
];
|
||||
// TODO: handle mask
|
||||
|
||||
const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(q.dataType);
|
||||
|
||||
const components = getMaxComponents(parameters.headSize);
|
||||
const qInput = inputVariable('q', q.dataType, q.dims, components);
|
||||
const kInput = inputVariable('key', key.dataType, key.dims, components);
|
||||
const output = outputVariable('output', q.dataType, probsShape);
|
||||
|
||||
const vectorizedHeadSize = parameters.headSize / components;
|
||||
const M = parameters.sequenceLength;
|
||||
const N = parameters.totalSequenceLength;
|
||||
const K = vectorizedHeadSize;
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
|
||||
const dispatch = {
|
||||
x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
|
||||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
|
||||
const inputs = [q, key];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const M: u32 = ${M}u;
|
||||
const N: u32 = ${N}u;
|
||||
const K: u32 = ${K}u;
|
||||
const alpha: ${dataType} = ${alpha};
|
||||
const beta: ${dataType} = 1.0;
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
${shaderHelper.declareVariables(qInput, kInput, output)}
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
// x holds the N and y holds the M
|
||||
let headIdx = workgroup_id.z;
|
||||
let m = workgroup_id.y * TILE_SIZE;
|
||||
let n = workgroup_id.x * TILE_SIZE;
|
||||
let lm = m + local_id.y;
|
||||
let ln = n + local_id.x;
|
||||
|
||||
let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
|
||||
let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
|
||||
|
||||
var value = ${fillVector(dataType, components)};
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m + local_id.y < M && w + local_id.x < K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x];
|
||||
}
|
||||
if (n + local_id.y < N && w + local_id.x < K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x];
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let headOffset = headIdx * M * N;
|
||||
if (lm < M && ln < N) {
|
||||
let outputIdx = headOffset + lm * N + ln;
|
||||
output[outputIdx] = ${sumVector('value', components)} * alpha;
|
||||
}
|
||||
}`;
|
||||
|
||||
const probs = context.compute(
|
||||
{
|
||||
name: 'AttentionProbs',
|
||||
shaderCache: {hint: JSON.stringify(parameters)},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{inputs, outputs: [-1]})[0];
|
||||
|
||||
computeInPlaceSoftmax(
|
||||
context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
|
||||
parameters.totalSequenceLength);
|
||||
|
||||
return probs;
|
||||
};
|
||||
|
||||
const computeVxAttentionScore =
|
||||
(context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
|
||||
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
|
||||
|
||||
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
|
||||
const vHelper = inputVariable('v', v.dataType, v.dims);
|
||||
const output = outputVariable('output', probs.dataType, outputShape);
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(probs.dataType);
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(params.vHeadSize / TILE_SIZE),
|
||||
y: Math.ceil(params.sequenceLength / TILE_SIZE),
|
||||
z: params.batchSize * params.numHeads
|
||||
};
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const M: u32 = ${params.sequenceLength}u;
|
||||
const N: u32 = ${params.vHeadSize}u;
|
||||
const K: u32 = ${params.totalSequenceLength}u;
|
||||
const numHeads: u32 = ${params.numHeads}u;
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
${shaderHelper.declareVariables(probsHelper, vHelper, output)}
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
let headIdx = workgroup_id.z;
|
||||
let m = workgroup_id.y * TILE_SIZE + local_id.y;
|
||||
let n = workgroup_id.x * TILE_SIZE + local_id.x;
|
||||
|
||||
let offsetA = headIdx * (M * K) + m * K;
|
||||
let offsetB = headIdx * (N * K) + n;
|
||||
|
||||
var value = ${dataType}(0);
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m < M && w + local_id.x < K) {
|
||||
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
|
||||
}
|
||||
if (n < N && w + local_id.y < K) {
|
||||
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
|
||||
}
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// we need to transpose output from BNSH_v to BSND_v
|
||||
let batchIdx = workgroup_id.z / ${params.numHeads};
|
||||
let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
|
||||
let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize};
|
||||
if (m < M && n < N) {
|
||||
let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
|
||||
+ currentBatchHeadNumber * ${params.vHeadSize} + n;
|
||||
output[outputIdx] = value;
|
||||
}
|
||||
}`;
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'AttentionScore',
|
||||
shaderCache: {hint: JSON.stringify(params)},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: dispatch,
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{inputs: [probs, v], outputs: [0]})[0];
|
||||
};
|
||||
|
||||
export const applyAttention =
|
||||
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
|
||||
_past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
|
||||
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
|
||||
const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
|
||||
|
||||
computeVxAttentionScore(context, probs, v, parameters);
|
||||
};
|
||||
|
||||
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
||||
const outputShape = [
|
||||
parameters.batchSize,
|
||||
parameters.numHeads,
|
||||
parameters.sequenceLength,
|
||||
parameters.headSize,
|
||||
];
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
|
||||
const M = parameters.sequenceLength;
|
||||
const K = parameters.inputHiddenSize;
|
||||
const N = parameters.headSize;
|
||||
|
||||
const TILE_SIZE = 12;
|
||||
const dispatch = {
|
||||
x: Math.ceil(parameters.headSize / TILE_SIZE),
|
||||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
|
||||
const getShaderSource = () => `
|
||||
const M: u32 = ${M}u;
|
||||
const K: u32 = ${K}u;
|
||||
const N: u32 = ${N}u;
|
||||
const numHeads: u32 = ${parameters.numHeads};
|
||||
const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u;
|
||||
const TILE_SIZE = ${TILE_SIZE}u;
|
||||
|
||||
var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
var<workgroup> tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> input: array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read> weight: array<${dataType}>;
|
||||
@group(0) @binding(2) var<storage, read> bias: array<${dataType}>;
|
||||
@group(0) @binding(3) var<storage, read_write> outputQ: array<${dataType}>;
|
||||
@group(0) @binding(4) var<storage, read_write> outputK: array<${dataType}>;
|
||||
@group(0) @binding(5) var<storage, read_write> outputV: array<${dataType}>;
|
||||
|
||||
@compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
|
||||
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
|
||||
@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
|
||||
let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
|
||||
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
|
||||
|
||||
let batchIndex = workgroup_id.z / ${parameters.numHeads};
|
||||
let headNumber = workgroup_id.z % ${parameters.numHeads};
|
||||
let m = workgroup_id.y * TILE_SIZE + local_id.y;
|
||||
let n = workgroup_id.x * TILE_SIZE + local_id.x;
|
||||
|
||||
let inputOffset = batchIndex * (M * K) + m * K;
|
||||
let biasOffsetQ = headNumber * ${parameters.headSize};
|
||||
let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ;
|
||||
let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK;
|
||||
|
||||
var valueQ = ${dataType}(0);
|
||||
var valueK = ${dataType}(0);
|
||||
var valueV = ${dataType}(0);
|
||||
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
|
||||
if (m < M && w + local_id.x < K) {
|
||||
tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
|
||||
}
|
||||
if (n < N && w + local_id.y < K) {
|
||||
let offset = n + (w + local_id.y) * ldb;
|
||||
tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
|
||||
tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
|
||||
tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
|
||||
}
|
||||
workgroupBarrier();
|
||||
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
|
||||
let inputTileOffset = TILE_SIZE * local_id.y + k;
|
||||
let weightTileOffset = TILE_SIZE * k + local_id.x;
|
||||
valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
|
||||
valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];
|
||||
valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
let headOffset = (m * N + n) % ${parameters.headSize};
|
||||
valueQ += bias[headOffset + biasOffsetQ];
|
||||
valueK += bias[headOffset + biasOffsetK];
|
||||
valueV += bias[headOffset + biasOffsetV];
|
||||
|
||||
let offset = workgroup_id.z * M * N;
|
||||
if (m < M && n < N) {
|
||||
let outputIdx = offset + m * N + n;
|
||||
outputQ[outputIdx] = valueQ;
|
||||
outputK[outputIdx] = valueK;
|
||||
outputV[outputIdx] = valueV;
|
||||
}
|
||||
}`;
|
||||
|
||||
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'AttentionPrepare',
|
||||
shaderCache: {hint: JSON.stringify(parameters)},
|
||||
getRunData: () => ({
|
||||
outputs: [
|
||||
{dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
{dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
{dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default},
|
||||
],
|
||||
dispatchGroup: dispatch,
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{inputs, outputs: [-1, -1, -1]});
|
||||
};
|
||||
|
||||
export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
|
||||
const params = validateAttentionInputs(context.inputs, attributes);
|
||||
|
||||
const [q, k, v] = prepare(context, params);
|
||||
|
||||
return applyAttention(
|
||||
context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes);
|
||||
};
|
||||
335
js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
Normal file
335
js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType} from '../types';
|
||||
|
||||
import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention';
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
import {createTransposeProgramInfo, TransposeAttributes} from './transpose';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
|
||||
const query = inputs[0];
|
||||
const key = inputs[1];
|
||||
const value = inputs[2];
|
||||
const bias = inputs[3];
|
||||
const keyPaddingMask = inputs[4];
|
||||
const relativePositionBias = inputs[5];
|
||||
const pastKey = inputs[6];
|
||||
const pastValue = inputs[7];
|
||||
|
||||
// Abbreviation and Meanings:
|
||||
// B: batch_size
|
||||
// S: sequence_length (input sequence length of query)
|
||||
// P: past_sequence_length (past sequence length of key or value)
|
||||
// L: kv_sequence_length (input sequence length of key or value)
|
||||
// M: max_sequence_length
|
||||
// T: total_sequence_length = past_sequence_length + kv_sequence_length
|
||||
// N: num_heads
|
||||
// H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
|
||||
// H_v: v_head_size
|
||||
// D_i: input hidden size
|
||||
// D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
|
||||
// D_v: v_hidden_size = num_heads * v_head_size
|
||||
|
||||
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
|
||||
// relative_position_bias : (B, 1, S, L)
|
||||
// past_key : (B, N, S*, H)
|
||||
// past_value : (B, N, S*, H)
|
||||
// When no packing for q/k/v:
|
||||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, L, D) or (B, N, S*, H)
|
||||
// value (V) : (B, L, D_v) or (B, N, S*, H)
|
||||
// bias (Q/K/V) : (D + D + D_v)
|
||||
// When packed kv is used:
|
||||
// query (Q) : (B, S, D)
|
||||
// key (K) : (B, L, N, 2, H)
|
||||
// value (V) : None
|
||||
// bias (Q/K/V) : None
|
||||
// When packed qkv is used:
|
||||
// query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
|
||||
// key (K) : None
|
||||
// value (V) : None
|
||||
// bias (Q/K/V) : None or (D + D + D_v)
|
||||
|
||||
if (query.dims.length !== 3 && query.dims.length !== 5) {
|
||||
throw new Error('Input query is expected to have 3 or 5 dimensions');
|
||||
}
|
||||
|
||||
const dmmhaPacking = false;
|
||||
const batchSize = query.dims[0];
|
||||
const sequenceLength = query.dims[1];
|
||||
const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) :
|
||||
attributes.numHeads * query.dims[4];
|
||||
let kvSequenceLength = sequenceLength;
|
||||
|
||||
let pastSequenceLength = 0;
|
||||
let maxSequenceLength = 0;
|
||||
const headSize = Math.floor(hiddenSize / attributes.numHeads);
|
||||
if (pastKey && pastValue) {
|
||||
if (pastKey.dims.length !== 4) {
|
||||
throw new Error('Input "past_key" is expected to have 4 dimensions');
|
||||
}
|
||||
if (pastValue.dims.length !== 4) {
|
||||
throw new Error('Input "past_value" is expected to have 4 dimensions');
|
||||
}
|
||||
pastSequenceLength = pastKey.dims[2];
|
||||
maxSequenceLength = pastKey.dims[2];
|
||||
} else if (pastKey || pastValue) {
|
||||
throw new Error('Input "past_key" and "past_value" shall be both present or both absent');
|
||||
}
|
||||
|
||||
let qkvFormat: AttentionQkvFormat;
|
||||
if (key) {
|
||||
if (query.dims.length !== 3) {
|
||||
throw new Error('Input "query" is expected to have 3 dimensions when key is given');
|
||||
}
|
||||
if (key.dims.length < 3 || key.dims.length > 5) {
|
||||
throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions');
|
||||
}
|
||||
if (query.dims[0] !== key.dims[0]) {
|
||||
throw new Error('Input "query" and "key" shall have same dim 0 (batch size)');
|
||||
}
|
||||
|
||||
if (key.dims.length === 3) {
|
||||
if (key.dims[2] !== query.dims[2]) {
|
||||
throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qkvBSNH;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else if (key.dims.length === 5) {
|
||||
if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) {
|
||||
throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv');
|
||||
}
|
||||
if (value) {
|
||||
throw new Error('Expect "value" be none when "key" has packed kv format.');
|
||||
}
|
||||
qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H;
|
||||
kvSequenceLength = key.dims[1];
|
||||
} else { // key_dims.size() == 4 (cross-attention with past_key)
|
||||
if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) {
|
||||
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
|
||||
}
|
||||
|
||||
qkvFormat = AttentionQkvFormat.unknown;
|
||||
kvSequenceLength = key.dims[2];
|
||||
}
|
||||
} else { // packed QKV
|
||||
if (query.dims.length !== 3 && query.dims.length !== 5) {
|
||||
throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
|
||||
}
|
||||
if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
|
||||
throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
|
||||
}
|
||||
|
||||
qkvFormat = AttentionQkvFormat.qkvBSN3H;
|
||||
}
|
||||
|
||||
if (bias) {
|
||||
if (bias.dims.length !== 1) {
|
||||
throw new Error('Input "bias" is expected to have 1 dimension');
|
||||
}
|
||||
|
||||
if (value) {
|
||||
if (query.dims.length === 5 && query.dims[3] === 2) {
|
||||
throw new Error('bias is not allowed for packed kv.');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let maskType: AttentionMaskType = AttentionMaskType.none;
|
||||
if (keyPaddingMask) {
|
||||
maskType = AttentionMaskType.maskUnknown;
|
||||
const maskDims = keyPaddingMask.dims;
|
||||
if (maskDims.length === 1) {
|
||||
if (maskDims[0] === batchSize) {
|
||||
maskType = AttentionMaskType.mask1dKeySeqLen;
|
||||
} else if (maskDims[0] === 3 * batchSize + 2) {
|
||||
maskType = AttentionMaskType.mask1DKeySeqLenStart;
|
||||
}
|
||||
} else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) {
|
||||
maskType = AttentionMaskType.mask2dKeyPadding;
|
||||
}
|
||||
if (maskType === AttentionMaskType.maskUnknown) {
|
||||
throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)');
|
||||
}
|
||||
throw new Error('Mask not supported');
|
||||
}
|
||||
|
||||
let passPastInKv = false;
|
||||
let vHiddenSize = hiddenSize;
|
||||
if (value) {
|
||||
if (value.dims.length !== 3 && value.dims.length !== 4) {
|
||||
throw new Error('Input "value" is expected to have 3 or 4 dimensions');
|
||||
}
|
||||
|
||||
if (query.dims[0] !== value.dims[0]) {
|
||||
throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)');
|
||||
}
|
||||
|
||||
if (value.dims.length === 3) {
|
||||
if (kvSequenceLength !== value.dims[1]) {
|
||||
throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)');
|
||||
}
|
||||
vHiddenSize = value.dims[2];
|
||||
} else {
|
||||
if (kvSequenceLength !== value.dims[2]) {
|
||||
throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
|
||||
}
|
||||
vHiddenSize = value.dims[1] * value.dims[3];
|
||||
passPastInKv = true;
|
||||
}
|
||||
}
|
||||
|
||||
const totalSequenceLength = pastSequenceLength + kvSequenceLength;
|
||||
const broadcastResPosBias = false;
|
||||
// if (extraAddQk) {
|
||||
// if (extraAddQk.dims[0] === 1) {
|
||||
// broadcastResPosBias = true;
|
||||
// }
|
||||
// }
|
||||
|
||||
if (keyPaddingMask) {
|
||||
throw new Error('Key padding mask is not supported');
|
||||
}
|
||||
if (relativePositionBias) {
|
||||
throw new Error('extraAddQk is not supported');
|
||||
}
|
||||
if (pastKey) {
|
||||
throw new Error('pastKey is not supported');
|
||||
}
|
||||
if (pastValue) {
|
||||
throw new Error('pastValue is not supported');
|
||||
}
|
||||
|
||||
return {
|
||||
batchSize,
|
||||
sequenceLength,
|
||||
pastSequenceLength,
|
||||
kvSequenceLength,
|
||||
totalSequenceLength,
|
||||
maxSequenceLength,
|
||||
inputHiddenSize: 0,
|
||||
hiddenSize,
|
||||
vHiddenSize,
|
||||
headSize,
|
||||
vHeadSize: Math.floor(vHiddenSize / attributes.numHeads),
|
||||
numHeads: attributes.numHeads,
|
||||
isUnidirectional: false,
|
||||
pastPresentShareBuffer: false,
|
||||
maskFilterValue: attributes.maskFilterValue,
|
||||
maskType,
|
||||
scale: attributes.scale,
|
||||
broadcastResPosBias,
|
||||
passPastInKv,
|
||||
qkvFormat,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs =>
|
||||
createAttributeWithCacheKey({...attributes});
|
||||
|
||||
const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]});
|
||||
|
||||
const addBiasTranspose =
|
||||
(context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number,
|
||||
hiddenSize: number, biasOffset: number) => {
|
||||
const outputShape = [batchSize, sequenceLength, hiddenSize];
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const dataType = tensorTypeToWsglStorageType(qkv.dataType);
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const biasOffset = ${biasOffset}u;
|
||||
const hiddenSize = ${hiddenSize}u;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> qkv: array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read> bias: array<${dataType}>;
|
||||
@group(0) @binding(2) var<storage, read_write> qkv_with_bias: array<${dataType}>;
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset;
|
||||
|
||||
qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx];
|
||||
}`;
|
||||
|
||||
return context.compute(
|
||||
{
|
||||
name: 'MultiHeadAttentionAddBias',
|
||||
shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{inputs: [qkv, bias], outputs: [-1]})[0];
|
||||
};
|
||||
|
||||
const maybeTransposeToBNSHAndAddBias =
|
||||
(context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number,
|
||||
input: TensorView, bias?: TensorView, biasOffset?: number) => {
|
||||
// const newDims = [];
|
||||
|
||||
let reshapedInput = input;
|
||||
if (!bias) {
|
||||
if (input.dims.length === 3) {
|
||||
reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
|
||||
}
|
||||
return context.compute(
|
||||
createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
|
||||
{inputs: [reshapedInput], outputs: [-1]})[0];
|
||||
} else {
|
||||
if (sequenceLength === 1) {
|
||||
throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV');
|
||||
} else {
|
||||
reshapedInput =
|
||||
addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!);
|
||||
reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
|
||||
return context.compute(
|
||||
createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
|
||||
{inputs: [reshapedInput], outputs: [-1]})[0];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => {
|
||||
const params = validateInputs(context.inputs, attributes);
|
||||
|
||||
if (context.inputs[0].dims.length === 5) {
|
||||
throw new Error('Packed QKV is not implemented');
|
||||
}
|
||||
|
||||
if (context.inputs[1]?.dims.length === 5) {
|
||||
throw new Error('Packed KV is not implemented');
|
||||
}
|
||||
|
||||
// applyAttention expects BNSH inputs
|
||||
const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 &&
|
||||
context.inputs[2].dims.length === 4;
|
||||
|
||||
const Q = maybeTransposeToBNSHAndAddBias(
|
||||
context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0],
|
||||
context.inputs[3], 0);
|
||||
|
||||
if (kvBNSH) {
|
||||
return applyAttention(
|
||||
context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined,
|
||||
context.inputs[5], params, attributes);
|
||||
}
|
||||
|
||||
const K = maybeTransposeToBNSHAndAddBias(
|
||||
context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1],
|
||||
context.inputs[3], params.hiddenSize);
|
||||
|
||||
const V = maybeTransposeToBNSHAndAddBias(
|
||||
context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2],
|
||||
context.inputs[3], 2 * params.hiddenSize);
|
||||
|
||||
applyAttention(
|
||||
context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params,
|
||||
attributes);
|
||||
};
|
||||
|
|
@ -16,6 +16,8 @@ const COMMENTS: Record<string, string> = {
|
|||
'Reshape': 'no GPU kernel',
|
||||
'Shape': 'no GPU kernel; an ORT warning is generated - need to fix',
|
||||
'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling',
|
||||
'Attention': 'need implementing mask and past/present',
|
||||
'MultiHeadAttention': 'need implementing mask and past/present',
|
||||
};
|
||||
|
||||
/* eslint-disable max-len */
|
||||
|
|
|
|||
557
js/web/test/data/ops/attention.jsonc
Normal file
557
js/web/test/data/ops/attention.jsonc
Normal file
|
|
@ -0,0 +1,557 @@
|
|||
[
|
||||
{
|
||||
"name": "Attention Basic",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [4, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [213, 213],
|
||||
"dims": [1, 2, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic Batch 2 with 2 heads",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16
|
||||
],
|
||||
"dims": [2, 2, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4,
|
||||
4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4
|
||||
],
|
||||
"dims": [8, 6],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [6],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [320, 321, 320, 321, 320, 321, 320, 321],
|
||||
"dims": [2, 2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863],
|
||||
"dims": [1, 3, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.1103, -1.6898, -0.989],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675],
|
||||
"dims": [1, 3, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic one head, batch 2",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094],
|
||||
"dims": [2, 3, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.1103, -1.6898, -0.989],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987
|
||||
],
|
||||
"dims": [2, 3, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 2 head, batch 1",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094],
|
||||
"dims": [2, 3, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643],
|
||||
"dims": [2, 6],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898],
|
||||
"dims": [6],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652,
|
||||
-2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789,
|
||||
1.7134947776794434, -1.765358328819275
|
||||
],
|
||||
"dims": [2, 3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 5 head, batch 2",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 5, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312,
|
||||
-1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156,
|
||||
-1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675,
|
||||
-0.1792980432510376, -0.26380985975265503, -0.25473490357398987
|
||||
],
|
||||
"dims": [2, 3, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643,
|
||||
0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312,
|
||||
0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236
|
||||
],
|
||||
"dims": [5, 15],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539,
|
||||
-1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434
|
||||
],
|
||||
"dims": [15],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336,
|
||||
-1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543,
|
||||
-1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395,
|
||||
3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541,
|
||||
2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125,
|
||||
-0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044
|
||||
],
|
||||
"dims": [2, 3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 5 head, batch 1",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 5, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846
|
||||
],
|
||||
"dims": [1, 3, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643,
|
||||
0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312,
|
||||
0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236
|
||||
],
|
||||
"dims": [5, 15],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
"dims": [15],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168,
|
||||
-1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405,
|
||||
-1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326
|
||||
],
|
||||
"dims": [1, 3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 5 head, batch 3",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 5, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229,
|
||||
-0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987
|
||||
],
|
||||
"dims": [3, 3, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643,
|
||||
0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312,
|
||||
0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236
|
||||
],
|
||||
"dims": [5, 15],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
"dims": [15],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168,
|
||||
-1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405,
|
||||
-1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326,
|
||||
-1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168,
|
||||
-1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405,
|
||||
-1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326,
|
||||
3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844,
|
||||
2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445,
|
||||
-0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635
|
||||
],
|
||||
"dims": [3, 3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 5 head, batch 3",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 5, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229,
|
||||
-0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674,
|
||||
0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345,
|
||||
0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987
|
||||
],
|
||||
"dims": [3, 3, 10],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643,
|
||||
0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312,
|
||||
0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345,
|
||||
0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312,
|
||||
-1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156,
|
||||
-1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675,
|
||||
-0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539,
|
||||
-1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312,
|
||||
-1.8803634643554688, 2.1661579608917236
|
||||
],
|
||||
"dims": [10, 15],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
-1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168,
|
||||
-1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405,
|
||||
-1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326
|
||||
],
|
||||
"dims": [15],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387,
|
||||
-8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707,
|
||||
-8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963,
|
||||
-6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291,
|
||||
-6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908,
|
||||
-1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527,
|
||||
-2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375,
|
||||
-2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455,
|
||||
-4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375
|
||||
],
|
||||
"dims": [3, 3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Attention Basic 1 head, batch 3",
|
||||
"operator": "Attention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229,
|
||||
-0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674,
|
||||
0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345,
|
||||
0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987
|
||||
],
|
||||
"dims": [3, 3, 10],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643,
|
||||
0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652,
|
||||
-1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236,
|
||||
1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185,
|
||||
-1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503,
|
||||
-0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312,
|
||||
0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22,
|
||||
3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709,
|
||||
0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688,
|
||||
2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866,
|
||||
-1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376,
|
||||
-0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345,
|
||||
0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652,
|
||||
0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312,
|
||||
-1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156,
|
||||
-1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675,
|
||||
-0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539,
|
||||
-1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312,
|
||||
-1.8803634643554688, 2.1661579608917236
|
||||
],
|
||||
"dims": [10, 15],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [
|
||||
-1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168,
|
||||
-1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405,
|
||||
-1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326
|
||||
],
|
||||
"dims": [15],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
-8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805,
|
||||
-8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805,
|
||||
-8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805,
|
||||
1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711,
|
||||
-1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551,
|
||||
-1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527,
|
||||
-2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375,
|
||||
-2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375,
|
||||
1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785
|
||||
],
|
||||
"dims": [3, 3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
194
js/web/test/data/ops/multi-head-attention.jsonc
Normal file
194
js/web/test/data/ops/multi-head-attention.jsonc
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
[
|
||||
{
|
||||
"name": "MultiHeadAttention Basic, one head",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994,
|
||||
5.999990940093994, 6.999990940093994, 7.999990940093994
|
||||
],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549,
|
||||
5.998325824737549, 6.999900817871094, 7.999900817871094
|
||||
],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention Basic with bias",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
|
||||
"dims": [12],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965,
|
||||
7.9997992515563965, 10, 11.999999046325684
|
||||
],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention two heads",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12,
|
||||
13, 14, 15, 16
|
||||
],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MultiHeadAttention two heads",
|
||||
"operator": "MultiHeadAttention",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 1, 1, 1, 2, 2, 2, 2],
|
||||
"dims": [1, 1, 8],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 1, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8],
|
||||
"dims": [1, 2, 8],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1336,6 +1336,7 @@
|
|||
"add_int32.jsonc",
|
||||
//"and.jsonc",
|
||||
"asin.jsonc",
|
||||
"attention.jsonc",
|
||||
"bias-add.jsonc",
|
||||
"bias-split-gelu.jsonc",
|
||||
"ceil.jsonc",
|
||||
|
|
@ -1362,6 +1363,7 @@
|
|||
"matmul-broadcast.jsonc",
|
||||
"mul.jsonc",
|
||||
"mul_int32.jsonc",
|
||||
"multi-head-attention.jsonc",
|
||||
//"neg.jsonc",
|
||||
"neg-int32.jsonc",
|
||||
"not.jsonc",
|
||||
|
|
|
|||
24
onnxruntime/contrib_ops/js/bert/attention.cc
Normal file
24
onnxruntime/contrib_ops/js/bert/attention.cc
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "attention.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsepSupportedFloatTypes;
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Attention,
|
||||
kMSDomain,
|
||||
1,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Attention);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
47
onnxruntime/contrib_ops/js/bert/attention.h
Normal file
47
onnxruntime/contrib_ops/js/bert/attention.h
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::contrib::AttentionBase;
|
||||
using onnxruntime::js::JsKernel;
|
||||
|
||||
class Attention : public JsKernel, AttentionBase {
|
||||
public:
|
||||
explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) {
|
||||
std::vector<int32_t> qkv_sizes(qkv_hidden_sizes_.size());
|
||||
if (qkv_hidden_sizes_.size() > 0) {
|
||||
std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(),
|
||||
[](int64_t sz) { return gsl::narrow_cast<int32_t>(sz); });
|
||||
}
|
||||
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({
|
||||
"numHeads" : $1,
|
||||
"isUnidirectional" : $2,
|
||||
"maskFilterValue" : $3,
|
||||
"scale" : $4,
|
||||
"doRotary" : $5,
|
||||
"qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [],
|
||||
"pastPresentShareBuffer" : !!$8,
|
||||
}),
|
||||
static_cast<int32_t>(num_heads_),
|
||||
static_cast<int32_t>(is_unidirectional_),
|
||||
static_cast<int32_t>(mask_filter_value_),
|
||||
static_cast<int32_t>(scale_),
|
||||
static_cast<int32_t>(do_rotary_),
|
||||
static_cast<int32_t>(qkv_hidden_sizes_.size()),
|
||||
reinterpret_cast<uintptr_t>((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2,
|
||||
static_cast<int32_t>(past_present_share_buffer_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/contrib_ops/js/bert/multi_head_attention.cc
Normal file
24
onnxruntime/contrib_ops/js/bert/multi_head_attention.cc
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "multi_head_attention.h"
|
||||
#include "core/providers/js/js_data_types.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsepSupportedFloatTypes;
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
MultiHeadAttention,
|
||||
kMSDomain,
|
||||
1,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
MultiHeadAttention);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
36
onnxruntime/contrib_ops/js/bert/multi_head_attention.h
Normal file
36
onnxruntime/contrib_ops/js/bert/multi_head_attention.h
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::contrib::AttentionBase;
|
||||
using onnxruntime::js::JsKernel;
|
||||
|
||||
class MultiHeadAttention : public JsKernel, AttentionBase {
|
||||
public:
|
||||
explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({
|
||||
"numHeads" : $1,
|
||||
"isUnidirectional" : $2,
|
||||
"maskFilterValue" : $3,
|
||||
"scale" : $4,
|
||||
"doRotary" : $5,
|
||||
}),
|
||||
static_cast<int32_t>(num_heads_),
|
||||
static_cast<int32_t>(is_unidirectional_),
|
||||
static_cast<int32_t>(mask_filter_value_),
|
||||
static_cast<int32_t>(scale_),
|
||||
static_cast<int32_t>(do_rotary_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -7,7 +7,9 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
|
||||
|
|
@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
|
||||
Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
|
||||
|
|
|
|||
Loading…
Reference in a new issue