onnxruntime/js/web/test/data/ops/multi-head-attention.jsonc
Arthur Islamov fac3e33da5
[js/web] JSEP Attention & MultiHeadAttention (#17742)
### Description
This is a narrow implementation of Attention/MultiHeadAttention, as it
does not support:
a. inputs 5-7 for MHA
b. packed QKV/KV
c. past/present
d. attention mask

But it works well for Stable Diffusion and can be extended later. It
reduces VRAM usage by fusing many ops into a few.
I've updated the demo at https://islamov.ai/stable-diffusion-webgpu/; it
takes ~13 s for one image with 20 steps on an RTX 3090 Ti and about
25 s on an M1 Pro.
VRAM usage is about 8 GB if you don't use img2img.

Going to focus on SDXL next.

---------

Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
2023-11-17 12:23:52 -08:00

[
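  // Self-attention with a single head: expected output is softmax(Q K^T / sqrt(head_size)) V with head_size = 4.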
  {
    "name": "MultiHeadAttention Basic, one head",
    "operator": "MultiHeadAttention",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
    "cases": [
      {
        "name": "T[0]",
        "inputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 1, 1, 1, 2, 2, 2, 2],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          }
        ],
        "outputs": [
          {
            "data": [
              4.973228454589844, 5.973228454589844, 6.973228454589844, 7.973228454589844, 4.999990940093994,
              5.999990940093994, 6.999990940093994, 7.999990940093994
            ],
            "dims": [1, 2, 4],
            "type": "float32"
          }
        ]
      }
    ]
  },
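  // Same Q/K/V as the single-head case, but split into two heads of size 2 before the scaled dot-product.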
  {
    "name": "MultiHeadAttention Basic",
    "operator": "MultiHeadAttention",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
    "cases": [
      {
        "name": "T[0]",
        "inputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 1, 1, 1, 2, 2, 2, 2],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          }
        ],
        "outputs": [
          {
            "data": [
              4.571832656860352, 5.571832656860352, 6.971858501434326, 7.971858501434326, 4.998325824737549,
              5.998325824737549, 6.999900817871094, 7.999900817871094
            ],
            "dims": [1, 2, 4],
            "type": "float32"
          }
        ]
      }
    ]
  },
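  // Exercises the optional bias input: 3 * hidden = 12 values, consumed as concatenated Q, K and V biases.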
  {
    "name": "MultiHeadAttention Basic with bias",
    "operator": "MultiHeadAttention",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
    "cases": [
      {
        "name": "T[0]",
        "inputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 1, 1, 1, 2, 2, 2, 2],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 4],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4],
            "dims": [12],
            "type": "float32"
          }
        ],
        "outputs": [
          {
            "data": [
              5.943336009979248, 7.94333553314209, 9.999799728393555, 11.999798774719238, 5.9997992515563965,
              7.9997992515563965, 10, 11.999999046325684
            ],
            "dims": [1, 2, 4],
            "type": "float32"
          }
        ]
      }
    ]
  },
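  // Two heads over hidden size 8 (head size 4), with equal Q and K/V sequence lengths.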
  {
    "name": "MultiHeadAttention two heads",
    "operator": "MultiHeadAttention",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
    "cases": [
      {
        "name": "T[0]",
        "inputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            "dims": [1, 2, 8],
            "type": "float32"
          },
          {
            "data": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
            "dims": [1, 2, 8],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            "dims": [1, 2, 8],
            "type": "float32"
          }
        ],
        "outputs": [
          {
            "data": [
              8.99963665008545, 9.99963665008545, 10.99963665008545, 11.999635696411133, 13, 14, 15, 16, 9, 10, 11, 12,
              13, 14, 15, 16
            ],
            "dims": [1, 2, 8],
            "type": "float32"
          }
        ]
      }
    ]
  },
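  // Cross-attention shapes: K and V have kv_sequence_length 1, so each query attends only to that key and the output repeats V.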
  {
    "name": "MultiHeadAttention two heads",
    "operator": "MultiHeadAttention",
    "opset": { "domain": "com.microsoft", "version": 1 },
    "attributes": [{ "name": "num_heads", "data": 2, "type": "int" }],
    "cases": [
      {
        "name": "T[1]",
        "inputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            "dims": [1, 2, 8],
            "type": "float32"
          },
          {
            "data": [1, 1, 1, 1, 2, 2, 2, 2],
            "dims": [1, 1, 8],
            "type": "float32"
          },
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 1, 8],
            "type": "float32"
          }
        ],
        "outputs": [
          {
            "data": [1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8],
            "dims": [1, 2, 8],
            "type": "float32"
          }
        ]
      }
    ]
  }
]
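
For reference, the expected values in the cases above can be reproduced with a small NumPy sketch of the 3-input path this test data exercises (pre-projected Q/K/V, an optional concatenated bias, no mask and no past/present). `mha_reference` is an illustrative name, not part of the test suite or the JSEP implementation:

```python
import numpy as np

def mha_reference(q, k, v, num_heads, bias=None):
    """Scaled dot-product multi-head attention over pre-projected inputs.

    q: [batch, q_seq, hidden]; k, v: [batch, kv_seq, hidden];
    bias (optional): [3 * hidden], the concatenated Q/K/V biases.
    """
    batch, q_seq, hidden = q.shape
    kv_seq = k.shape[1]
    head_size = hidden // num_heads
    if bias is not None:
        q_bias, k_bias, v_bias = np.split(bias, 3)
        q, k, v = q + q_bias, k + k_bias, v + v_bias

    def split_heads(x, seq):
        # [batch, seq, hidden] -> [batch, num_heads, seq, head_size]
        return x.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    qh, kh, vh = split_heads(q, q_seq), split_heads(k, kv_seq), split_heads(v, kv_seq)
    scores = qh @ kh.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    probs /= probs.sum(axis=-1, keepdims=True)
    out = probs @ vh  # [batch, num_heads, q_seq, head_size]
    return out.transpose(0, 2, 1, 3).reshape(batch, q_seq, hidden)

# "MultiHeadAttention Basic, one head": matches the expected data above
# up to float32 rounding (4.9732285..., 5.9732285..., ..., 7.999991...).
q = np.arange(1, 9, dtype=np.float32).reshape(1, 2, 4)
k = np.array([1, 1, 1, 1, 2, 2, 2, 2], dtype=np.float32).reshape(1, 2, 4)
print(mha_reference(q, k, q.copy(), num_heads=1))
```

Passing num_heads=2 and the 12-element bias from the "Basic with bias" case reproduces that case's expected output as well, which pins down the concatenated Q/K/V bias layout used by the test data.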