From fac3e33da510c27c7a2631cf44a79923ee14e09f Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Sat, 18 Nov 2023 00:23:52 +0400 Subject: [PATCH] [js/web] JSEP Attention & MultiHeadAttention (#17742) ### Description This is a narrow implementation of Attention/MultiHeadAttention as it does not support: a. inputs 5-7 for MHA b. packed QKV/KV c. past/present d. attention mask But it works well for StableDiffusion and can be extended later. It reduces VRAM usage as it combines many ops into few I've updated demo here https://islamov.ai/stable-diffusion-webgpu/ it takes ~13sec for 1 image with 20 steps on RTX3090Ti and about 25s on M1 Pro VRAM usage is about 8gb if you don't use img2img Going to focus on SDXL now --------- Co-authored-by: Guenther Schmuelling Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/docs/webgpu-operators.md | 2 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 4 + js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 635 ++++++++++++++++++ .../jsep/webgpu/ops/multi-head-attentiion.ts | 335 +++++++++ js/web/script/generate-webgpu-operator-md.ts | 2 + js/web/test/data/ops/attention.jsonc | 557 +++++++++++++++ .../test/data/ops/multi-head-attention.jsonc | 194 ++++++ js/web/test/suite-test-list.jsonc | 2 + onnxruntime/contrib_ops/js/bert/attention.cc | 24 + onnxruntime/contrib_ops/js/bert/attention.h | 47 ++ .../js/bert/multi_head_attention.cc | 24 + .../js/bert/multi_head_attention.h | 36 + .../contrib_ops/js/js_contrib_kernels.cc | 4 + 13 files changed, 1866 insertions(+) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/attention.ts create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts create mode 100644 js/web/test/data/ops/attention.jsonc create mode 100644 js/web/test/data/ops/multi-head-attention.jsonc create mode 100644 onnxruntime/contrib_ops/js/bert/attention.cc create mode 100644 onnxruntime/contrib_ops/js/bert/attention.h create mode 100644 onnxruntime/contrib_ops/js/bert/multi_head_attention.cc create mode 100644 onnxruntime/contrib_ops/js/bert/multi_head_attention.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0b82a9c031..b246e19137 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,6 +20,7 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation | | BiasAdd | com.microsoft(1+) | | | BiasSplitGelu | com.microsoft(1+) | | @@ -61,6 +62,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a4d51e68b6..9f5dceb8f4 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {attention, parseAttentionAttributes} from './ops/attention'; import {biasAdd} from './ops/bias-add'; import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; @@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; import {range} from './ops/range'; @@ -46,6 +48,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Asinh', [unaryOps.asinh]], ['Atan', [unaryOps.atan]], ['Atanh', [unaryOps.atanh]], + ['Attention', [attention, parseAttentionAttributes]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], ['BiasAdd', [biasAdd]], @@ -86,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], + ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]], ['Neg', [unaryOps.neg]], ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts new file mode 100644 index 0000000000..e1f2a47301 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -0,0 +1,635 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; + +export const enum AttentionQkvFormat { + unknown, // enum value not set, or depends on qkv projection implementation details + qkvBNSH, // for non-packed qkv, permuted + qkvBSNH, // for non-packed qkv, not permuted, used by memory efficient attention or MultiHeadAttention + qkvBSN3H, // for TRT fused attention, qkv are packed + qkvBNSHqkvBS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) + qKvBSNHxBSN2H, // for TRT fused cross attention, kv are packed + qkvTNH, // for memory efficient attention, qkv are not packed, and paddings are removed. + qkvTN3H, // for TRT fused attention, qkv are packed and paddings are removed +} + +export const enum AttentionMaskType { + none, // No mask + mask1dKeySeqLen, // [batch_size], key sequence length + mask1dEndStart, // [2 * batch_size] with end positions and start positions + mask1DKeySeqLenStart, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], + // ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., + // key_start[batch_size - 1], key_end[batch_size - 1]] + mask2dDummy, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask. + mask2dKeyPadding, // [batch_size, total_sequence_length] + mask3dAttention, // [batch_size, sequence_length, total_sequence_length] + mask4dMegatron, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length] + maskUnknown +} + +export interface AttentionParameters { + batchSize: number; + sequenceLength: number; + pastSequenceLength: number; + kvSequenceLength: number; + totalSequenceLength: number; + maxSequenceLength: number; + inputHiddenSize: number; + hiddenSize: number; + vHiddenSize: number; + headSize: number; + vHeadSize: number; + numHeads: number; + isUnidirectional: boolean; + pastPresentShareBuffer: boolean; + maskFilterValue: number; + maskType: AttentionMaskType; + scale: number; + broadcastResPosBias: boolean; + passPastInKv: boolean; + qkvFormat: AttentionQkvFormat; +} + +export interface AttentionAttrs { + numHeads: number; + isUnidirectional: number; + maskFilterValue: number; + scale: number; + doRotary: number; + qkvHiddenSizes: number[]; + pastPresentShareBuffer: boolean; +} + +const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // When past state is used, Q, K and V should have same hidden size (unless we split it into past_key and past_value). + + // Input shapes: + // input (Q/K/V) : (B, S, D_i) + // weights (Q/K/V) : (D_i, D + D + D_v) + // bias (Q/K/V) : (D + D + D_v) + // mask_index : see below + // past (K/V) : (2, B, N, P, H) or NULL + // relative_position_bias : (B, N, S, T) or NULL + + // For mask_index, the following shapes are supported: + // NULL, (B, 1), (1, 1) + // (B), (2 * B), (3 * B + 2) + // (B, T) + // (B, S, T) + // (B, 1, M, M) + // + // When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger + // than hidden dimension of Q, K and V. + + const input = inputs[0]; + const weights = inputs[1]; + const bias = inputs[2]; + const maskIndex = inputs[3]; + const past = inputs[4]; + const relativePositionBias = inputs[5]; + + if (past && relativePositionBias) { + throw new Error('Attention cannot have both past and relative_position_bias'); + } + + if (input.dims.length !== 3) { + throw new Error('Input "input" must have 3 dimensions'); + } + + const batchSize = input.dims[0]; + const sequenceLength = input.dims[1]; + const inputHiddenSize = input.dims[2]; + + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimensions'); + } + + if (weights.dims.length !== 2) { + throw new Error('Input "weights" is expected to have 2 dimensions'); + } + + if (weights.dims[0] !== inputHiddenSize) { + throw new Error('Input 1 dimension 0 should have same length as dimension 2 of input 0'); + } + + if (bias.dims[0] !== weights.dims[1]) { + throw new Error('Input "bias" dimension 0 should have same length as dimension 1 of input "weights"'); + } + + let qHiddenSize = bias.dims[0] / 3; + let kHiddenSize = qHiddenSize; + let vHiddenSize = kHiddenSize; + if (attributes.qkvHiddenSizes.length > 0) { + if (attributes.qkvHiddenSizes.length !== 3) { + throw new Error('qkv_hidden_sizes attribute should have 3 elements'); + } + for (const sz of attributes.qkvHiddenSizes) { + if (sz % attributes.numHeads !== 0) { + throw new Error('qkv_hidden_sizes should be divisible by num_heads'); + } + } + + qHiddenSize = attributes.qkvHiddenSizes[0]; + kHiddenSize = attributes.qkvHiddenSizes[1]; + vHiddenSize = attributes.qkvHiddenSizes[2]; + } + + const kvSequenceLength = sequenceLength; + + if (qHiddenSize !== kHiddenSize) { + throw new Error('qkv_hidden_sizes first element should be same as the second'); + } + + if (bias.dims[0] !== qHiddenSize + kHiddenSize + vHiddenSize) { + throw new Error('Input "bias" dimension 0 should have same length as sum of Q/K/V hidden sizes'); + } + + let pastSequenceLength = 0; + if (past) { + if (kHiddenSize !== vHiddenSize) { + throw new Error('Input "past" expect k_hidden_size == v_hidden_size'); + } + if (past.dims.length !== 5) { + throw new Error('Input "past" must have 5 dimensions'); + } + if (past.dims[0] !== 2) { + throw new Error('Input "past" first dimension must be 2'); + } + if (past.dims[1] !== batchSize) { + throw new Error('Input "past" second dimension must be batch_size'); + } + if (past.dims[2] !== attributes.numHeads) { + throw new Error('Input "past" third dimension must be num_heads'); + } + if (past.dims[4] !== kHiddenSize / attributes.numHeads) { + throw new Error('Input "past" fifth dimension must be k_hidden_size / num_heads'); + } + + if (!attributes.pastPresentShareBuffer) { + pastSequenceLength = past.dims[3]; + } + // TODO: handle past_seq_len + } + + const totalSequenceLength = kvSequenceLength + pastSequenceLength; + const maxSequenceLength = -1; + + const maskType = AttentionMaskType.none; + if (maskIndex) { + // maskType = AttentionMaskType.MASK_UNKNOWN; + // TODO: handle mask + throw new Error('Mask not supported'); + } + + if (past) { + throw new Error('past is not supported'); + } + if (relativePositionBias) { + throw new Error('relativePositionBias is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize, + hiddenSize: qHiddenSize, + vHiddenSize, + headSize: Math.floor(qHiddenSize / attributes.numHeads), + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias: false, + passPastInKv: false, + qkvFormat: AttentionQkvFormat.qkvBNSH, + }; +}; + +export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, n: number, d: number) => { + const components = getMaxComponents(d); + const inputHelper = outputVariable('x', input.dataType, input.dims, components); + + let threadMaxValue = 'threadMaxVector'; + if (components === 2) { + threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)'; + } else if (components === 4) { + threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))'; + } + const dataType = tensorTypeToWsglStorageType(input.dataType); + let WG = 64; + const dComp = d / components; + if (dComp < WG) { + WG = 1; + } else if (dComp / 8 < 64) { + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const dInv: ${dataType} = 1 / ${d}; + const dComp = ${d / components}; + var wgMax: array; + var wgSum: array; + + ${shaderHelper.declareVariables(inputHelper)} + @compute @workgroup_size(${WG}, 1, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_index : u32) { + let localOffset = local_index * ${elementsPerWG}; + let offset: u32 = workgroup_id.x * dComp + localOffset; + + var threadMaxVector = ${fillVector('f32', components, '-3.402823e+38f')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + threadMaxVector = max(${castToF32(dataType, components, 'x[offset + i]')}, threadMaxVector); + } + wgMax[local_index] = ${threadMaxValue}; + workgroupBarrier(); + + var maxValue = -3.402823e+38f; + for (var i = 0u; i < ${WG}; i++) { + maxValue = max(wgMax[i], maxValue); + } + + var sumVector = ${fillVector('f32', components, '0')}; + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + sumVector += exp(${castToF32(dataType, components, 'x[offset + i]')} - maxValue); + } + wgSum[local_index] = ${sumVector('sumVector', components)}; + workgroupBarrier(); + + var sum: f32 = 0; + for (var i = 0u; i < ${WG}; i++) { + sum += wgSum[i]; + } + + if (sum == 0) { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + x[offset + i] = ${fillVector(dataType, components, 'dInv')}; + } + } else { + for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) { + let f32input = ${castToF32(dataType, components, 'x[offset + i]')}; + x[offset + i] = ${inputHelper.type.value}(exp(f32input - maxValue) / sum); + } + } + }`; + + context.compute( + { + name: 'AttentionProbsSoftmax', + shaderCache: {hint: `${d}`}, + getShaderSource, + getRunData: () => ({ + outputs: [], + dispatchGroup: {x: n}, + }), + }, + {inputs: [input], outputs: []}); +}; + +const computeAttentionProbs = + (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined, + parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probsShape = [ + parameters.batchSize, parameters.numHeads, parameters.sequenceLength, + parameters.kvSequenceLength + parameters.pastSequenceLength + ]; + // TODO: handle mask + + const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; + + const dataType = tensorTypeToWsglStorageType(q.dataType); + + const components = getMaxComponents(parameters.headSize); + const qInput = inputVariable('q', q.dataType, q.dims, components); + const kInput = inputVariable('key', key.dataType, key.dims, components); + const output = outputVariable('output', q.dataType, probsShape); + + const vectorizedHeadSize = parameters.headSize / components; + const M = parameters.sequenceLength; + const N = parameters.totalSequenceLength; + const K = vectorizedHeadSize; + + const TILE_SIZE = 12; + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const inputs = [q, key]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${M}u; + const N: u32 = ${N}u; + const K: u32 = ${K}u; + const alpha: ${dataType} = ${alpha}; + const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(qInput, kInput, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; + + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value = ${fillVector(dataType, components)}; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; + } + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k ({ + outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1]})[0]; + + computeInPlaceSoftmax( + context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength, + parameters.totalSequenceLength); + + return probs; + }; + +const computeVxAttentionScore = + (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { + const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; + + const probsHelper = inputVariable('probs', probs.dataType, probs.dims); + const vHelper = inputVariable('v', v.dataType, v.dims); + const output = outputVariable('output', probs.dataType, outputShape); + + const dataType = tensorTypeToWsglStorageType(probs.dataType); + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(params.vHeadSize / TILE_SIZE), + y: Math.ceil(params.sequenceLength / TILE_SIZE), + z: params.batchSize * params.numHeads + }; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M: u32 = ${params.sequenceLength}u; + const N: u32 = ${params.vHeadSize}u; + const K: u32 = ${params.totalSequenceLength}u; + const numHeads: u32 = ${params.numHeads}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; + + ${shaderHelper.declareVariables(probsHelper, vHelper, output)} + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let offsetA = headIdx * (M * K) + m * K; + let offsetB = headIdx * (N * K) + n; + + var value = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs: [probs, v], outputs: [0]})[0]; + }; + +export const applyAttention = + (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined, + _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined, + relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { + const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); + + computeVxAttentionScore(context, probs, v, parameters); + }; + +const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + const outputShape = [ + parameters.batchSize, + parameters.numHeads, + parameters.sequenceLength, + parameters.headSize, + ]; + + const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType); + + const M = parameters.sequenceLength; + const K = parameters.inputHiddenSize; + const N = parameters.headSize; + + const TILE_SIZE = 12; + const dispatch = { + x: Math.ceil(parameters.headSize / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; + + const getShaderSource = () => ` + const M: u32 = ${M}u; + const K: u32 = ${K}u; + const N: u32 = ${N}u; + const numHeads: u32 = ${parameters.numHeads}; + const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileWeightV: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var weight: array<${dataType}>; + @group(0) @binding(2) var bias: array<${dataType}>; + @group(0) @binding(3) var outputQ: array<${dataType}>; + @group(0) @binding(4) var outputK: array<${dataType}>; + @group(0) @binding(5) var outputV: array<${dataType}>; + + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; + + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; + + let inputOffset = batchIndex * (M * K) + m * K; + let biasOffsetQ = headNumber * ${parameters.headSize}; + let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; + let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; + + var valueQ = ${dataType}(0); + var valueK = ${dataType}(0); + var valueV = ${dataType}(0); + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; + } + if (n < N && w + local_id.y < K) { + let offset = n + (w + local_id.y) * ldb; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset]; + } + workgroupBarrier(); + for (var k: u32 = 0u; k ({ + outputs: [ + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + {dims: outputShape, dataType: context.inputs[0].dataType, gpuDataType: GpuDataType.default}, + ], + dispatchGroup: dispatch, + }), + getShaderSource, + }, + {inputs, outputs: [-1, -1, -1]}); +}; + +export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateAttentionInputs(context.inputs, attributes); + + const [q, k, v] = prepare(context, params); + + return applyAttention( + context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts new file mode 100644 index 0000000000..b7726a36bc --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -0,0 +1,335 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType} from '../types'; + +import {applyAttention, AttentionAttrs, AttentionMaskType, AttentionParameters, AttentionQkvFormat} from './attention'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {createTransposeProgramInfo, TransposeAttributes} from './transpose'; + +const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => { + const query = inputs[0]; + const key = inputs[1]; + const value = inputs[2]; + const bias = inputs[3]; + const keyPaddingMask = inputs[4]; + const relativePositionBias = inputs[5]; + const pastKey = inputs[6]; + const pastValue = inputs[7]; + + // Abbreviation and Meanings: + // B: batch_size + // S: sequence_length (input sequence length of query) + // P: past_sequence_length (past sequence length of key or value) + // L: kv_sequence_length (input sequence length of key or value) + // M: max_sequence_length + // T: total_sequence_length = past_sequence_length + kv_sequence_length + // N: num_heads + // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size + // H_v: v_head_size + // D_i: input hidden size + // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size + // D_v: v_hidden_size = num_heads * v_head_size + + // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // past_key : (B, N, S*, H) + // past_value : (B, N, S*, H) + // When no packing for q/k/v: + // query (Q) : (B, S, D) + // key (K) : (B, L, D) or (B, N, S*, H) + // value (V) : (B, L, D_v) or (B, N, S*, H) + // bias (Q/K/V) : (D + D + D_v) + // When packed kv is used: + // query (Q) : (B, S, D) + // key (K) : (B, L, N, 2, H) + // value (V) : None + // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) or (B, S, 3*D) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None or (D + D + D_v) + + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input query is expected to have 3 or 5 dimensions'); + } + + const dmmhaPacking = false; + const batchSize = query.dims[0]; + const sequenceLength = query.dims[1]; + const hiddenSize = query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : + attributes.numHeads * query.dims[4]; + let kvSequenceLength = sequenceLength; + + let pastSequenceLength = 0; + let maxSequenceLength = 0; + const headSize = Math.floor(hiddenSize / attributes.numHeads); + if (pastKey && pastValue) { + if (pastKey.dims.length !== 4) { + throw new Error('Input "past_key" is expected to have 4 dimensions'); + } + if (pastValue.dims.length !== 4) { + throw new Error('Input "past_value" is expected to have 4 dimensions'); + } + pastSequenceLength = pastKey.dims[2]; + maxSequenceLength = pastKey.dims[2]; + } else if (pastKey || pastValue) { + throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); + } + + let qkvFormat: AttentionQkvFormat; + if (key) { + if (query.dims.length !== 3) { + throw new Error('Input "query" is expected to have 3 dimensions when key is given'); + } + if (key.dims.length < 3 || key.dims.length > 5) { + throw new Error('Input "key" is expected to have 3, 4, or 5 dimensions'); + } + if (query.dims[0] !== key.dims[0]) { + throw new Error('Input "query" and "key" shall have same dim 0 (batch size)'); + } + + if (key.dims.length === 3) { + if (key.dims[2] !== query.dims[2]) { + throw new Error('Input "query" and "key" shall have same dim 2 (hidden_size)'); + } + qkvFormat = AttentionQkvFormat.qkvBSNH; + kvSequenceLength = key.dims[1]; + } else if (key.dims.length === 5) { + if (key.dims[2] !== attributes.numHeads || key.dims[3] !== 2 || key.dims[4] !== headSize) { + throw new Error('Expect "key" shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv'); + } + if (value) { + throw new Error('Expect "value" be none when "key" has packed kv format.'); + } + qkvFormat = AttentionQkvFormat.qKvBSNHxBSN2H; + kvSequenceLength = key.dims[1]; + } else { // key_dims.size() == 4 (cross-attention with past_key) + if (key.dims[1] !== attributes.numHeads || key.dims[3] !== headSize) { + throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key'); + } + + qkvFormat = AttentionQkvFormat.unknown; + kvSequenceLength = key.dims[2]; + } + } else { // packed QKV + if (query.dims.length !== 3 && query.dims.length !== 5) { + throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty'); + } + if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) { + throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv'); + } + + qkvFormat = AttentionQkvFormat.qkvBSN3H; + } + + if (bias) { + if (bias.dims.length !== 1) { + throw new Error('Input "bias" is expected to have 1 dimension'); + } + + if (value) { + if (query.dims.length === 5 && query.dims[3] === 2) { + throw new Error('bias is not allowed for packed kv.'); + } + } + } + + let maskType: AttentionMaskType = AttentionMaskType.none; + if (keyPaddingMask) { + maskType = AttentionMaskType.maskUnknown; + const maskDims = keyPaddingMask.dims; + if (maskDims.length === 1) { + if (maskDims[0] === batchSize) { + maskType = AttentionMaskType.mask1dKeySeqLen; + } else if (maskDims[0] === 3 * batchSize + 2) { + maskType = AttentionMaskType.mask1DKeySeqLenStart; + } + } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) { + maskType = AttentionMaskType.mask2dKeyPadding; + } + if (maskType === AttentionMaskType.maskUnknown) { + throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)'); + } + throw new Error('Mask not supported'); + } + + let passPastInKv = false; + let vHiddenSize = hiddenSize; + if (value) { + if (value.dims.length !== 3 && value.dims.length !== 4) { + throw new Error('Input "value" is expected to have 3 or 4 dimensions'); + } + + if (query.dims[0] !== value.dims[0]) { + throw new Error('Input "query" and "value" shall have same dim 0 (batch_size)'); + } + + if (value.dims.length === 3) { + if (kvSequenceLength !== value.dims[1]) { + throw new Error('Input "key" and "value" shall have the same dim 1 (kv_sequence_length)'); + } + vHiddenSize = value.dims[2]; + } else { + if (kvSequenceLength !== value.dims[2]) { + throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)'); + } + vHiddenSize = value.dims[1] * value.dims[3]; + passPastInKv = true; + } + } + + const totalSequenceLength = pastSequenceLength + kvSequenceLength; + const broadcastResPosBias = false; + // if (extraAddQk) { + // if (extraAddQk.dims[0] === 1) { + // broadcastResPosBias = true; + // } + // } + + if (keyPaddingMask) { + throw new Error('Key padding mask is not supported'); + } + if (relativePositionBias) { + throw new Error('extraAddQk is not supported'); + } + if (pastKey) { + throw new Error('pastKey is not supported'); + } + if (pastValue) { + throw new Error('pastValue is not supported'); + } + + return { + batchSize, + sequenceLength, + pastSequenceLength, + kvSequenceLength, + totalSequenceLength, + maxSequenceLength, + inputHiddenSize: 0, + hiddenSize, + vHiddenSize, + headSize, + vHeadSize: Math.floor(vHiddenSize / attributes.numHeads), + numHeads: attributes.numHeads, + isUnidirectional: false, + pastPresentShareBuffer: false, + maskFilterValue: attributes.maskFilterValue, + maskType, + scale: attributes.scale, + broadcastResPosBias, + passPastInKv, + qkvFormat, + }; +}; + + +export const parseMultiHeadAttentionAttributes = (attributes: AttentionAttrs): AttentionAttrs => + createAttributeWithCacheKey({...attributes}); + +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [0, 2, 1, 3]}); + +const addBiasTranspose = + (context: ComputeContext, qkv: TensorView, bias: TensorView, batchSize: number, sequenceLength: number, + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); + + const dataType = tensorTypeToWsglStorageType(qkv.dataType); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const biasOffset = ${biasOffset}u; + const hiddenSize = ${hiddenSize}u; + + @group(0) @binding(0) var qkv: array<${dataType}>; + @group(0) @binding(1) var bias: array<${dataType}>; + @group(0) @binding(2) var qkv_with_bias: array<${dataType}>; + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasOffsetIdx = (global_idx % hiddenSize) + biasOffset; + + qkv_with_bias[global_idx] = qkv[global_idx] + bias[biasOffsetIdx]; + }`; + + return context.compute( + { + name: 'MultiHeadAttentionAddBias', + shaderCache: {hint: JSON.stringify({batchSize, sequenceLength, hiddenSize, biasOffset})}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: qkv.dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + }), + getShaderSource, + }, + {inputs: [qkv, bias], outputs: [-1]})[0]; + }; + +const maybeTransposeToBNSHAndAddBias = + (context: ComputeContext, batchSize: number, numHeads: number, sequenceLength: number, headSize: number, + input: TensorView, bias?: TensorView, biasOffset?: number) => { + // const newDims = []; + + let reshapedInput = input; + if (!bias) { + if (input.dims.length === 3) { + reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); + } + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } else { + if (sequenceLength === 1) { + throw new Error('AddBiasReshape is not implemented. Please export your model with packed QKV or KV'); + } else { + reshapedInput = + addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!); + reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + return context.compute( + createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), + {inputs: [reshapedInput], outputs: [-1]})[0]; + } + } + }; + +export const multiHeadAttention = (context: ComputeContext, attributes: AttentionAttrs): void => { + const params = validateInputs(context.inputs, attributes); + + if (context.inputs[0].dims.length === 5) { + throw new Error('Packed QKV is not implemented'); + } + + if (context.inputs[1]?.dims.length === 5) { + throw new Error('Packed KV is not implemented'); + } + + // applyAttention expects BNSH inputs + const kvBNSH = context.inputs[1] && context.inputs[2] && context.inputs[1].dims.length === 4 && + context.inputs[2].dims.length === 4; + + const Q = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.sequenceLength, params.headSize, context.inputs[0], + context.inputs[3], 0); + + if (kvBNSH) { + return applyAttention( + context, Q, context.inputs[1], context.inputs[2], context.inputs[4], undefined, undefined, undefined, + context.inputs[5], params, attributes); + } + + const K = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.headSize, context.inputs[1], + context.inputs[3], params.hiddenSize); + + const V = maybeTransposeToBNSHAndAddBias( + context, params.batchSize, params.numHeads, params.kvSequenceLength, params.vHeadSize, context.inputs[2], + context.inputs[3], 2 * params.hiddenSize); + + applyAttention( + context, Q, K, V, context.inputs[4], undefined, context.inputs[6], context.inputs[7], context.inputs[5], params, + attributes); +}; diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004..eab8175a94 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc new file mode 100644 index 0000000000..bd4483027c --- /dev/null +++ b/js/web/test/data/ops/attention.jsonc @@ -0,0 +1,557 @@ +[ + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [4, 3], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [213, 213], + "dims": [1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic Batch 2 with 2 heads", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 + ], + "dims": [8, 6], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [320, 321, 320, 321, 320, 321, 320, 321], + "dims": [2, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863], + "dims": [1, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-1.328187108039856, -1.297916054725647, -0.8599594831466675], + "dims": [1, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic one head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [2, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 2 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094], + "dims": [2, 3, 2], + "type": "float32" + }, + { + "data": [2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643], + "dims": [2, 6], + "type": "float32" + }, + { + "data": [1.1103, -1.6898, -0.989, -0.989, 1.1103, -1.6898], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.8701779842376709, -2.6158859729766846, 0.8710794448852539, -2.5763747692108154, 0.9005484580993652, + -2.182751178741455, 2.1661579608917236, -2.1045265197753906, 1.6716957092285156, -1.797281265258789, + 1.7134947776794434, -1.765358328819275 + ], + "dims": [2, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 2", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987 + ], + "dims": [2, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [ + 1.1103, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, -1.6898, -0.989, -1.9029953479766846, 0.8710794448852539, + -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, 1.7134947776794434 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.6956915855407715, -2.8863370418548584, 1.3899128437042236, 1.6789076328277588, -1.4083852767944336, + -1.7009180784225464, -3.1053788661956787, 3.5959298610687256, 1.1027096509933472, -0.009643087163567543, + -1.694351315498352, -2.9284396171569824, 1.734721302986145, 2.0606398582458496, -0.2571452260017395, + 3.671973943710327, -5.285338401794434, -6.833454132080078, 1.7506506443023682, -2.262148380279541, + 2.5110034942626953, 1.440049171447754, -0.9423203468322754, 1.7506506443023682, -1.86212158203125, + -0.5036701560020447, -5.732386589050293, -1.5674757957458496, 1.7506510019302368, -2.264472246170044 + ], + "dims": [2, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 1", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846 + ], + "dims": [1, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [1, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987 + ], + "dims": [3, 3, 5], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236 + ], + "dims": [5, 15], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326, + 3.7965505123138428, -2.3799397945404053, -3.9530906677246094, 0.5844926834106445, -2.9756431579589844, + 2.448162794113159, 4.34546422958374, 1.9380426406860352, 0.5870105624198914, -2.7368364334106445, + -0.4769568145275116, 4.255186557769775, -3.9529950618743896, 0.6987408995628357, -2.9756433963775635 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 5 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 5, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.01101303100586, -5.782258987426758, 6.016238689422607, 0.26747000217437744, -6.992541313171387, + -8.011263847351074, -5.782248020172119, 5.366001129150391, 0.26747000217437744, -6.99449348449707, + -8.011263847351074, -5.782265663146973, 6.016238689422607, 0.26747000217437744, -6.992537021636963, + -6.102723598480225, -7.28973388671875, -4.578637599945068, 7.2203369140625, -6.028444766998291, + -6.102705478668213, -7.2897748947143555, -3.7882626056671143, 5.393260478973389, -5.754333972930908, + -1.3616288900375366, -7.289827823638916, -6.341128349304199, 6.329389572143555, -5.751791954040527, + -2.3945987224578857, -14.532954216003418, 3.969801902770996, 12.744998931884766, -11.1966552734375, + -2.4002532958984375, -14.538958549499512, -6.684961318969727, 12.476543426513672, -9.24352741241455, + -4.787771701812744, -8.640848159790039, 3.969801902770996, -0.6471102833747864, -11.1966552734375 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/multi-head-attention.jsonc b/js/web/test/data/ops/multi-head-attention.jsonc new file mode 100644 index 0000000000..05687bd482 --- /dev/null +++ b/js/web/test/data/ops/multi-head-attention.jsonc @@ -0,0 +1,194 @@ +[ + { + "name": "MultiHeadAttention Basic, one head", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994, + 5.999990940093994, 6.999990940093994, 7.999990940093994 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549, + 5.998325824737549, 6.999900817871094, 7.999900817871094 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic with bias", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4], + "dims": [12], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965, + 7.9997992515563965, 10, 11.999999046325684 + ], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12, + 13, 14, 15, 16 + ], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention two heads", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }], + "cases": [ + { + "name": "T[1]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [1, 2, 8], + "type": "float32" + }, + { + "data": [1, 1, 1, 1, 2, 2, 2, 2], + "dims": [1, 1, 8], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 1, 8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [1, 2, 8], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c80f0b04a9..37aa9394c7 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1336,6 +1336,7 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "attention.jsonc", "bias-add.jsonc", "bias-split-gelu.jsonc", "ceil.jsonc", @@ -1362,6 +1363,7 @@ "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", + "multi-head-attention.jsonc", //"neg.jsonc", "neg-int32.jsonc", "not.jsonc", diff --git a/onnxruntime/contrib_ops/js/bert/attention.cc b/onnxruntime/contrib_ops/js/bert/attention.cc new file mode 100644 index 0000000000..723ff00aa8 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + Attention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + Attention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/attention.h b/onnxruntime/contrib_ops/js/bert/attention.h new file mode 100644 index 0000000000..0fa823befa --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/attention.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class Attention : public JsKernel, AttentionBase { + public: + explicit Attention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + std::vector qkv_sizes(qkv_hidden_sizes_.size()); + if (qkv_hidden_sizes_.size() > 0) { + std::transform(qkv_hidden_sizes_.begin(), qkv_hidden_sizes_.end(), qkv_sizes.begin(), + [](int64_t sz) { return gsl::narrow_cast(sz); }); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(Attention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + "qkvHiddenSizes" : $6 ? (Array.from(HEAP32.subarray(Number($7), Number($7) + $6))) : [], + "pastPresentShareBuffer" : !!$8, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_), + static_cast(qkv_hidden_sizes_.size()), + reinterpret_cast((qkv_sizes.size() > 0) ? qkv_sizes.data() : nullptr) >> 2, + static_cast(past_present_share_buffer_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc new file mode 100644 index 0000000000..c43f8b7f18 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "multi_head_attention.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + MultiHeadAttention); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bert/multi_head_attention.h b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h new file mode 100644 index 0000000000..6c63a2ffed --- /dev/null +++ b/onnxruntime/contrib_ops/js/bert/multi_head_attention.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/bert/attention_base.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::contrib::AttentionBase; +using onnxruntime::js::JsKernel; + +class MultiHeadAttention : public JsKernel, AttentionBase { + public: + explicit MultiHeadAttention(const OpKernelInfo& info) : JsKernel(info), AttentionBase(info, false) { + JSEP_INIT_KERNEL_ATTRIBUTE(MultiHeadAttention, ({ + "numHeads" : $1, + "isUnidirectional" : $2, + "maskFilterValue" : $3, + "scale" : $4, + "doRotary" : $5, + }), + static_cast(num_heads_), + static_cast(is_unidirectional_), + static_cast(mask_filter_value_), + static_cast(scale_), + static_cast(do_rotary_)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 24d327576e..498a9f5679 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -7,7 +7,9 @@ namespace onnxruntime { namespace contrib { namespace js { +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); @@ -21,7 +23,9 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo