mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[CUDA] Benchmark GQA on popular LLM models (#20646)
### Description
Update benchmark_gqa.py to test latency on popular models (like
Llama3-8b, Llama3-70b, Mixtral-8x22B-v0.1 and Phi-3 etc).
Note that this is latency of just one GroupQueryAttention node, not the
whole model. For example, packed QKV might need more time in GQA, but it
is faster in MatMul of input projection, the overall effect is not
measured here.
Example output in A100-SXM4-80GB :
```
prompt-sm80-Llama3-8B-b1-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.019073 0.016264
1 32.0 0.017768 0.017957
2 64.0 0.023304 0.023192
3 128.0 0.032541 0.031348
4 256.0 0.048329 0.049484
5 512.0 0.095294 0.095950
6 1024.0 0.228050 0.228980
7 2048.0 0.663820 0.663308
8 4096.0 2.243657 2.242999
9 8192.0 8.197120 8.186282
token-sm80-Llama3-8B-b1-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018516 0.015398
1 32.0 0.015687 0.016079
2 64.0 0.016115 0.016053
3 128.0 0.018727 0.019413
4 256.0 0.036373 0.035962
5 512.0 0.041701 0.042203
6 1024.0 0.053730 0.053750
7 2048.0 0.076382 0.075707
8 4096.0 0.121876 0.121802
9 8191.0 0.211292 0.211254
prompt-sm80-Llama3-8B-b4-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.024558 0.022070
1 32.0 0.021276 0.021406
2 64.0 0.044172 0.027789
3 128.0 0.069100 0.059071
4 256.0 0.146569 0.106717
5 512.0 0.270472 0.244461
6 1024.0 0.690024 0.692501
7 2048.0 2.308546 2.325453
8 4096.0 8.724295 8.957337
9 8192.0 39.030785 41.381378
token-sm80-Llama3-8B-b4-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018893 0.018611
1 32.0 0.018124 0.018190
2 64.0 0.018115 0.018156
3 128.0 0.023291 0.023733
4 256.0 0.038357 0.038351
5 512.0 0.047117 0.047792
6 1024.0 0.066272 0.065409
7 2048.0 0.104196 0.104527
8 4096.0 0.180557 0.180424
9 8191.0 0.332545 0.332714
prompt-sm80-Llama3-70B-b1-h64_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.040974 0.015852
1 32.0 0.017839 0.018615
2 64.0 0.023956 0.022704
3 128.0 0.044622 0.035229
4 256.0 0.080241 0.075237
5 512.0 0.143457 0.144322
6 1024.0 0.380473 0.381731
7 2048.0 1.217328 1.214505
8 4096.0 4.305315 4.286324
9 8192.0 15.918250 15.933440
token-sm80-Llama3-70B-b1-h64_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.016148 0.015612
1 32.0 0.015616 0.015616
2 64.0 0.016082 0.016070
3 128.0 0.019470 0.019130
4 256.0 0.036617 0.037296
5 512.0 0.042087 0.042176
6 1024.0 0.053704 0.053587
7 2048.0 0.076918 0.076365
8 4096.0 0.122534 0.121984
9 8191.0 0.212961 0.213330
prompt-sm80-Llama3-70B-b4-h64_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.031137 0.026270
1 32.0 0.030938 0.032009
2 64.0 0.040833 0.059118
3 128.0 0.084899 0.085482
4 256.0 0.163951 0.166310
5 512.0 0.420436 0.423721
6 1024.0 1.282019 1.283482
7 2048.0 4.397661 4.420121
8 4096.0 16.931839 17.456945
9 8192.0 77.896706 83.007484
token-sm80-Llama3-70B-b4-h64_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.026106 0.026061
1 32.0 0.025678 0.025589
2 64.0 0.025438 0.025965
3 128.0 0.033879 0.033320
4 256.0 0.058078 0.057656
5 512.0 0.078010 0.078153
6 1024.0 0.106353 0.098079
7 2048.0 0.160039 0.159153
8 4096.0 0.282527 0.283346
9 8191.0 0.546207 0.542135
prompt-sm80-Mistral-7B-v0.1-b1-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV
0 16.0 0.015722 0.015655 0.015666 0.016150
1 32.0 0.018590 0.018562 0.018136 0.024617
2 64.0 0.022480 0.023085 0.023184 0.023160
3 128.0 0.029948 0.030581 0.030839 0.031464
4 256.0 0.048532 0.049099 0.049424 0.049408
5 512.0 0.095096 0.095665 0.096174 0.096175
6 1024.0 0.228606 0.228942 0.228434 0.229568
7 2048.0 0.660832 0.661943 0.662170 0.663979
8 4096.0 2.238001 2.243999 2.242243 2.241707
9 8192.0 8.173824 6.147072 8.187648 6.152822
10 16384.0 33.826305 14.486015 34.849792 14.938283
11 32768.0 176.702469 32.725330 184.309753 34.736130
token-sm80-Mistral-7B-v0.1-b1-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV
0 16.0 0.015407 0.016042 0.016030 0.015429
1 32.0 0.015525 0.016115 0.016768 0.016052
2 64.0 0.015556 0.016079 0.015383 0.016008
3 128.0 0.019302 0.018644 0.018680 0.019278
4 256.0 0.036924 0.035900 0.036753 0.036786
5 512.0 0.041482 0.041434 0.041646 0.042238
6 1024.0 0.053587 0.052972 0.052888 0.052856
7 2048.0 0.075749 0.075807 0.076528 0.075945
8 4096.0 0.122053 0.122016 0.122115 0.122216
9 8192.0 0.212069 0.121317 0.211919 0.121087
10 16384.0 0.394036 0.121202 0.393661 0.121483
11 32767.0 0.757216 0.124326 0.757659 0.124157
prompt-sm80-Mistral-7B-v0.1-b4-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV
0 16.0 0.018418 0.018911 0.023387 0.019256
1 32.0 0.021085 0.021132 0.022143 0.022251
2 64.0 0.026743 0.026770 0.027942 0.027714
3 128.0 0.057922 0.058483 0.058800 0.059402
4 256.0 0.105927 0.104876 0.106695 0.105996
5 512.0 0.242958 0.242543 0.244599 0.244774
6 1024.0 0.689321 0.689347 0.691759 0.692334
7 2048.0 2.308250 2.304410 2.321587 2.317875
8 4096.0 8.705210 8.713682 8.927418 8.903866
9 8192.0 39.630848 28.227926 41.604607 29.648554
10 16384.0 175.553543 61.422592 183.384064 64.560127
11 32768.0 772.296692 132.006912 813.537292 138.996735
token-sm80-Mistral-7B-v0.1-b4-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV
0 16.0 0.018127 0.018691 0.018661 0.018681
1 32.0 0.018183 0.018812 0.018739 0.018759
2 64.0 0.018081 0.018116 0.018136 0.018153
3 128.0 0.023257 0.023146 0.023114 0.023103
4 256.0 0.038665 0.038102 0.038120 0.038759
5 512.0 0.047181 0.047156 0.047012 0.046382
6 1024.0 0.066047 0.066103 0.066604 0.066076
7 2048.0 0.104427 0.103770 0.103799 0.103807
8 4096.0 0.180951 0.180373 0.180173 0.180154
9 8192.0 0.334018 0.180801 0.333269 0.180690
10 16384.0 0.638682 0.180965 0.638543 0.180202
11 32767.0 1.249536 0.184779 1.249963 0.184624
prompt-sm80-Mixtral-8x22B-v0.1-b1-h48_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.015699 0.015563
1 32.0 0.017931 0.017719
2 64.0 0.029975 0.022875
3 128.0 0.031038 0.055747
4 256.0 0.050191 0.050845
5 512.0 0.125187 0.122813
6 1024.0 0.304004 0.301824
7 2048.0 0.936454 0.931546
8 4096.0 3.264547 3.255931
9 8192.0 12.062719 12.030080
10 16384.0 49.018368 48.970749
11 32768.0 261.211151 254.461945
12 65536.0 1221.138428 1197.559814
token-sm80-Mixtral-8x22B-v0.1-b1-h48_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.015980 0.016024
1 32.0 0.015440 0.016165
2 64.0 0.015987 0.015979
3 128.0 0.020837 0.018715
4 256.0 0.036240 0.036747
5 512.0 0.042477 0.041813
6 1024.0 0.052950 0.052956
7 2048.0 0.076084 0.076691
8 4096.0 0.122233 0.121540
9 8192.0 0.212469 0.212433
10 16384.0 0.394937 0.394996
11 32768.0 0.757285 0.757257
12 65535.0 1.484867 1.485015
prompt-sm80-Mixtral-8x22B-v0.1-b4-h48_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.024119 0.018755
1 32.0 0.022214 0.022267
2 64.0 0.028045 0.027562
3 128.0 0.062894 0.079766
4 256.0 0.135146 0.134483
5 512.0 0.331323 0.329094
6 1024.0 0.984576 0.982221
7 2048.0 3.353564 3.351021
8 4096.0 12.762113 12.778350
9 8192.0 58.599422 57.704449
10 16384.0 263.392242 258.709503
11 32768.0 1155.789795 1128.622070
12 65536.0 5014.187012 4874.590332
token-sm80-Mixtral-8x22B-v0.1-b4-h48_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018148 0.018813
1 32.0 0.018929 0.018840
2 64.0 0.018745 0.018232
3 128.0 0.023864 0.023822
4 256.0 0.038603 0.038694
5 512.0 0.048347 0.047630
6 1024.0 0.066957 0.067392
7 2048.0 0.105094 0.105058
8 4096.0 0.181941 0.181808
9 8192.0 0.334227 0.334324
10 16384.0 0.640429 0.640961
11 32768.0 1.267897 1.269120
12 65535.0 2.534238 2.504408
prompt-sm80-Phi-3-mini-128k-b1-h32_32x96-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.016112 0.026949
1 32.0 0.016486 0.017284
2 64.0 0.020910 0.020994
3 128.0 0.029306 0.029452
4 256.0 0.044604 0.044642
5 512.0 0.090079 0.086868
6 1024.0 0.208169 0.208094
7 2048.0 0.604687 0.607910
8 4096.0 2.029056 2.046771
9 8192.0 7.792128 7.906303
10 16384.0 34.271233 34.418175
11 32768.0 160.377853 159.980545
12 65536.0 733.443054 734.722046
token-sm80-Phi-3-mini-128k-b1-h32_32_d96-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.016339 0.015718
1 32.0 0.016572 0.015964
2 64.0 0.016182 0.016192
3 128.0 0.019373 0.018621
4 256.0 0.021856 0.022463
5 512.0 0.028943 0.028888
6 1024.0 0.041124 0.041104
7 2048.0 0.067668 0.067542
8 4096.0 0.117528 0.117447
9 8192.0 0.216241 0.215492
10 16384.0 0.413434 0.414047
11 32768.0 0.811085 0.810612
12 65536.0 1.606189 1.606458
13 131071.0 3.193037 3.192491
prompt-sm80-Phi-3-mini-128k-b4-h32_32x96-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.019385 0.019403
1 32.0 0.019801 0.020006
2 64.0 0.025958 0.025376
3 128.0 0.056445 0.055909
4 256.0 0.103180 0.102221
5 512.0 0.244224 0.244360
6 1024.0 0.703066 0.709327
7 2048.0 2.307456 2.335001
8 4096.0 8.334522 8.406760
9 8192.0 33.340416 33.758209
10 16384.0 144.141312 145.005569
11 32768.0 655.496216 655.656982
12 65536.0 2981.463135 2984.790039
token-sm80-Phi-3-mini-128k-b4-h32_32_d96-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018701 0.018185
1 32.0 0.020625 0.019213
2 64.0 0.019936 0.019943
3 128.0 0.023648 0.023689
4 256.0 0.030309 0.030305
5 512.0 0.043501 0.043801
6 1024.0 0.067314 0.068014
7 2048.0 0.108649 0.108134
8 4096.0 0.186053 0.186848
9 8192.0 0.339973 0.339742
10 16384.0 0.643288 0.644366
11 32768.0 1.261468 1.261510
12 65536.0 2.502252 2.501820
13 131071.0 4.990437 4.989521
prompt-sm80-Phi-3-small-128k-b1-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.025280 0.023331
1 32.0 0.023071 0.025931
2 64.0 0.022883 0.026258
3 128.0 0.030658 0.031445
4 256.0 0.057659 0.057073
5 512.0 0.095589 0.106579
6 1024.0 0.228532 0.229402
7 2048.0 0.662315 0.663349
8 4096.0 2.242885 2.248095
9 8192.0 8.194646 8.180395
10 16384.0 33.926659 35.130882
11 32768.0 175.320068 184.967163
12 65536.0 810.447876 847.632385
token-sm80-Phi-3-small-128k-b1-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.015517 0.016038
1 32.0 0.016372 0.015477
2 64.0 0.015472 0.016016
3 128.0 0.019291 0.018664
4 256.0 0.036250 0.035990
5 512.0 0.041691 0.042238
6 1024.0 0.053730 0.053126
7 2048.0 0.075912 0.076439
8 4096.0 0.121336 0.121334
9 8192.0 0.213104 0.212443
10 16384.0 0.394353 0.394272
11 32768.0 0.756965 0.757017
12 65536.0 1.484548 1.485371
13 131071.0 2.939200 2.939552
prompt-sm80-Phi-3-small-128k-b4-h32_8x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.044326 0.019298
1 32.0 0.021840 0.021408
2 64.0 0.027492 0.027802
3 128.0 0.058128 0.059431
4 256.0 0.104300 0.106019
5 512.0 0.242562 0.244948
6 1024.0 0.689614 0.692305
7 2048.0 2.297931 2.312857
8 4096.0 8.654848 8.843170
9 8192.0 38.770176 40.929279
10 16384.0 175.572998 183.692291
11 32768.0 780.126221 820.551697
12 65536.0 3357.564941 3488.527344
token-sm80-Phi-3-small-128k-b4-h32_8_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018061 0.017995
1 32.0 0.018225 0.018851
2 64.0 0.018203 0.018104
3 128.0 0.023161 0.023651
4 256.0 0.038421 0.037673
5 512.0 0.047590 0.046938
6 1024.0 0.065639 0.066055
7 2048.0 0.103545 0.103581
8 4096.0 0.180461 0.179998
9 8192.0 0.332667 0.332564
10 16384.0 0.638503 0.639094
11 32768.0 1.249180 1.249479
12 65536.0 2.469457 2.471666
13 131071.0 4.915362 4.914499
prompt-sm80-Phi-3-medium-128K-b1-h40_10x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.025759 0.016318
1 32.0 0.018282 0.018111
2 64.0 0.022642 0.022978
3 128.0 0.030860 0.037988
4 256.0 0.055703 0.050318
5 512.0 0.113465 0.113776
6 1024.0 0.267678 0.268292
7 2048.0 0.795202 0.797222
8 4096.0 2.737953 2.740435
9 8192.0 10.101760 10.149092
10 16384.0 43.326466 43.990013
11 32768.0 230.886398 229.886978
12 65536.0 1067.412476 1052.922852
token-sm80-Phi-3-medium-128K-b1-h40_10_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.016122 0.015582
1 32.0 0.015594 0.016262
2 64.0 0.016099 0.015512
3 128.0 0.018708 0.019510
4 256.0 0.037582 0.036341
5 512.0 0.042411 0.041894
6 1024.0 0.053278 0.053914
7 2048.0 0.076553 0.076636
8 4096.0 0.121539 0.121610
9 8192.0 0.212083 0.212377
10 16384.0 0.395086 0.395280
11 32768.0 0.757879 0.757888
12 65536.0 1.486093 1.486915
13 131071.0 2.941728 2.941408
prompt-sm80-Phi-3-medium-128K-b4-h40_10x128-fp16:
sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.019448 0.018872
1 32.0 0.022290 0.022380
2 64.0 0.027986 0.027955
3 128.0 0.062699 0.062175
4 256.0 0.124868 0.125247
5 512.0 0.298873 0.298169
6 1024.0 0.862584 0.863467
7 2048.0 2.944640 2.957824
8 4096.0 11.318656 11.390720
9 8192.0 52.606976 52.019199
10 16384.0 232.616959 230.360062
11 32768.0 1024.171997 1019.540466
12 65536.0 4377.362305 4354.510742
token-sm80-Phi-3-medium-128K-b4-h40_10_d128-fp16:
past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV
0 16.0 0.018192 0.018175
1 32.0 0.018999 0.018319
2 64.0 0.018447 0.018897
3 128.0 0.023863 0.023195
4 256.0 0.037712 0.038192
5 512.0 0.048863 0.048548
6 1024.0 0.067244 0.066473
7 2048.0 0.105203 0.105021
8 4096.0 0.180712 0.180429
9 8192.0 0.334948 0.334734
10 16384.0 0.640662 0.639709
11 32768.0 1.252196 1.251684
12 65536.0 2.474927 2.474280
13 131071.0 4.930829 4.959340
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
cfe830b248
commit
85facd678b
1 changed files with 103 additions and 302 deletions
|
|
@ -1,302 +1,75 @@
|
|||
import math
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
"""
|
||||
Benchmark performance of GroupQueryAttention.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
from onnxruntime import InferenceSession, SessionOptions
|
||||
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
|
||||
from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention
|
||||
|
||||
|
||||
class AttentionConfig:
|
||||
def __init__(
|
||||
self,
|
||||
operator: str,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
max_sequence_length: int,
|
||||
past_sequence_length: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
softmax_scale: Optional[float],
|
||||
do_rotary: bool,
|
||||
rotary_interleaved: bool,
|
||||
device="cuda",
|
||||
dtype=torch.float16,
|
||||
share_buffer: bool = True,
|
||||
is_packed_qkv: bool = False,
|
||||
):
|
||||
self.operator = operator
|
||||
self.batch_size = batch_size
|
||||
self.sequence_length = sequence_length
|
||||
self.max_sequence_length = max_sequence_length
|
||||
self.past_sequence_length = past_sequence_length
|
||||
self.num_heads = num_heads
|
||||
self.kv_num_heads = kv_num_heads
|
||||
self.head_size = head_size
|
||||
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
|
||||
|
||||
# Derived values
|
||||
self.total_sequence_length = sequence_length + past_sequence_length
|
||||
self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
|
||||
self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
|
||||
|
||||
self.do_rotary = do_rotary
|
||||
self.rotary_interleaved = rotary_interleaved
|
||||
self.device = device
|
||||
|
||||
self.share_buffer = share_buffer
|
||||
self.is_packed_qkv = is_packed_qkv
|
||||
self.dtype = dtype
|
||||
|
||||
def shape_dict(self):
|
||||
return {
|
||||
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
|
||||
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
|
||||
"total_sequence_length": (1,),
|
||||
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
|
||||
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
|
||||
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
|
||||
}
|
||||
|
||||
def get_cos_sin_cache(self, dtype):
|
||||
rotary_fraction = 1.0
|
||||
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
|
||||
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
|
||||
cos = torch.cos(angle).to(dtype=dtype)
|
||||
sin = torch.sin(angle).to(dtype=dtype)
|
||||
return cos.to(device=self.device), sin.to(device=self.device)
|
||||
|
||||
def random_inputs(self):
|
||||
device = self.device
|
||||
# bfloat16 is not supported in ORT python I/O binding API
|
||||
dtype = torch.float16
|
||||
shape_dict = self.shape_dict()
|
||||
|
||||
torch.manual_seed(123)
|
||||
feeds = {
|
||||
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
|
||||
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
|
||||
}
|
||||
|
||||
if self.do_rotary:
|
||||
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
|
||||
feeds["cos_cache"] = cos_cache
|
||||
feeds["sin_cache"] = sin_cache
|
||||
|
||||
return feeds
|
||||
|
||||
|
||||
class GroupQueryAttentionConfig(AttentionConfig):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int,
|
||||
sequence_length: int,
|
||||
max_sequence_length: int,
|
||||
past_sequence_length: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
softmax_scale=None,
|
||||
do_rotary: bool = False,
|
||||
rotary_interleaved: bool = False,
|
||||
device="cuda",
|
||||
local_window_size: int = -1,
|
||||
):
|
||||
super().__init__(
|
||||
"GroupQueryAttention",
|
||||
batch_size,
|
||||
sequence_length,
|
||||
max_sequence_length,
|
||||
past_sequence_length,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
head_size,
|
||||
softmax_scale,
|
||||
do_rotary,
|
||||
rotary_interleaved,
|
||||
device,
|
||||
)
|
||||
self.local_window_size = local_window_size
|
||||
|
||||
def shape_dict(self):
|
||||
shapes = super().shape_dict()
|
||||
shapes.update(
|
||||
{
|
||||
"seqlens_k": (self.batch_size,),
|
||||
}
|
||||
)
|
||||
return shapes
|
||||
|
||||
def random_inputs(self):
|
||||
feeds = super().random_inputs()
|
||||
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
|
||||
feeds.update(
|
||||
{
|
||||
"seqlens_k": k_seqlens - 1,
|
||||
}
|
||||
)
|
||||
return feeds
|
||||
|
||||
|
||||
def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
|
||||
assert config.dtype == torch.float16
|
||||
|
||||
float_type = TensorProto.FLOAT16
|
||||
nodes = [
|
||||
helper.make_node(
|
||||
"GroupQueryAttention",
|
||||
[
|
||||
"query",
|
||||
"key" if not config.is_packed_qkv else "",
|
||||
"value" if not config.is_packed_qkv else "",
|
||||
"past_key",
|
||||
"past_value",
|
||||
"seqlens_k",
|
||||
"total_sequence_length" if config.share_buffer else "",
|
||||
"cos_cache" if config.do_rotary else "",
|
||||
"sin_cache" if config.do_rotary else "",
|
||||
],
|
||||
["output", "present_key", "present_value"],
|
||||
"GroupQueryAttention_0",
|
||||
num_heads=config.num_heads,
|
||||
kv_num_heads=config.kv_num_heads,
|
||||
scale=config.softmax_scale,
|
||||
local_window_size=config.local_window_size,
|
||||
do_rotary=1 if config.do_rotary else 0,
|
||||
rotary_interleaved=config.rotary_interleaved,
|
||||
domain="com.microsoft",
|
||||
),
|
||||
]
|
||||
|
||||
shape_dict = config.shape_dict()
|
||||
graph_input = [
|
||||
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
|
||||
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
|
||||
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
|
||||
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
|
||||
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
|
||||
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
|
||||
helper.make_tensor_value_info(
|
||||
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
|
||||
),
|
||||
]
|
||||
|
||||
if config.do_rotary:
|
||||
graph_input += [
|
||||
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
|
||||
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
|
||||
]
|
||||
|
||||
graph_output = [
|
||||
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
|
||||
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
|
||||
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
|
||||
]
|
||||
|
||||
graph = helper.make_graph(
|
||||
nodes,
|
||||
"GroupQueryAttention_Graph",
|
||||
graph_input,
|
||||
graph_output,
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
return model.SerializeToString()
|
||||
|
||||
|
||||
def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession:
|
||||
session_options = SessionOptions()
|
||||
ort_session = InferenceSession(
|
||||
onnx_model_str,
|
||||
session_options,
|
||||
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
|
||||
)
|
||||
return ort_session
|
||||
|
||||
|
||||
class OrtGroupQueryAttention:
|
||||
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
|
||||
|
||||
def __init__(self, config: GroupQueryAttentionConfig):
|
||||
device = config.device
|
||||
cuda_provider_options = CudaSession.get_cuda_provider_options(
|
||||
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
|
||||
)
|
||||
onnx_model_str = create_group_query_attention_onnx_model(config)
|
||||
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
|
||||
self.gpu_binding_manager = GpuBindingManager(
|
||||
ort_session=self.ort_session,
|
||||
device=device,
|
||||
stream=torch.cuda.current_stream().cuda_stream,
|
||||
max_cuda_graphs=2,
|
||||
)
|
||||
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
|
||||
self.gpu_binding = self.gpu_binding_manager.get_binding(
|
||||
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
|
||||
)
|
||||
self.feed_dict = config.random_inputs()
|
||||
|
||||
def infer(self):
|
||||
return self.gpu_binding.infer(self.feed_dict)
|
||||
|
||||
|
||||
def get_plot_algos(sm: int):
|
||||
def get_plot_algos(sm: int, local_window_size: Optional[int]):
|
||||
# GQA with local windows only works in sm=8x
|
||||
if sm >= 80:
|
||||
if sm >= 80 and local_window_size:
|
||||
return {
|
||||
"line_vals": ["ort_gqa", "ort_gqa_local"],
|
||||
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"],
|
||||
"styles": [("red", "-"), ("blue", "-")],
|
||||
"line_vals": ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"],
|
||||
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local", "ORT-GQA-Dense-PackedQKV", "ORT-GQA-Local-PackedQKV"],
|
||||
"styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"line_vals": ["ort_gqa"],
|
||||
"line_names": ["ORT-GQA-Dense"],
|
||||
"styles": [("green", "-")],
|
||||
"line_vals": ["ort_gqa", "ort_gqa_packed"],
|
||||
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"],
|
||||
"styles": [("red", "solid"), ("blue", "dashed")],
|
||||
}
|
||||
|
||||
|
||||
def plot_prompt_performance(
|
||||
sm: int,
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
kv_num_heads=8,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
local_window_size: Optional[int] = None,
|
||||
):
|
||||
import triton
|
||||
|
||||
algos = get_plot_algos(sm)
|
||||
algos = get_plot_algos(sm, local_window_size)
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 14)],
|
||||
x_vals=[2**i for i in range(4, 17) if 2**i <= max_seq_len],
|
||||
line_arg="provider",
|
||||
ylabel="ms",
|
||||
**algos,
|
||||
plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
|
||||
plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}x{head_size}-fp16",
|
||||
args={
|
||||
"batch_size": batch_size,
|
||||
"num_heads": num_heads,
|
||||
"kv_num_heads": kv_num_heads,
|
||||
"batch_size": batch_size,
|
||||
"head_size": head_size,
|
||||
"local_window_size": local_window_size,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"):
|
||||
def benchmark(
|
||||
provider: str,
|
||||
sequence_length: int,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
local_window_size: Optional[int] = None,
|
||||
device="cuda",
|
||||
):
|
||||
warmup = 15
|
||||
repeat = 100
|
||||
|
||||
|
|
@ -308,8 +81,9 @@ def plot_prompt_performance(
|
|||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
local_window_size=1024 if provider == "ort_gqa_local" else -1,
|
||||
local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
|
||||
device=device,
|
||||
is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
|
||||
)
|
||||
|
||||
obj = OrtGroupQueryAttention(config)
|
||||
|
|
@ -322,40 +96,44 @@ def plot_prompt_performance(
|
|||
|
||||
def plot_token_performance(
|
||||
sm: int,
|
||||
batch_size=4,
|
||||
num_heads=32,
|
||||
kv_num_heads=8,
|
||||
max_seq_len=8192,
|
||||
head_size=128,
|
||||
model_name: str,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
local_window_size: Optional[int] = None,
|
||||
):
|
||||
import triton
|
||||
|
||||
algos = get_plot_algos(sm)
|
||||
algos = get_plot_algos(sm, local_window_size)
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["past_sequence_length"],
|
||||
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
|
||||
x_vals=[2**i for i in range(4, 17) if 2**i < max_seq_len] + [max_seq_len - 1],
|
||||
line_arg="provider",
|
||||
ylabel="ms",
|
||||
**algos,
|
||||
plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
|
||||
plot_name=f"token-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}_d{head_size}-fp16",
|
||||
args={
|
||||
"batch_size": batch_size,
|
||||
"num_heads": num_heads,
|
||||
"kv_num_heads": kv_num_heads,
|
||||
"batch_size": batch_size,
|
||||
"head_size": head_size,
|
||||
"local_window_size": local_window_size,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(
|
||||
batch_size,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
past_sequence_length,
|
||||
head_size,
|
||||
provider,
|
||||
provider: str,
|
||||
past_sequence_length: int,
|
||||
batch_size: int,
|
||||
num_heads: int,
|
||||
kv_num_heads: int,
|
||||
head_size: int,
|
||||
local_window_size: Optional[int] = None,
|
||||
device="cuda",
|
||||
):
|
||||
warmup = 15
|
||||
|
|
@ -369,7 +147,9 @@ def plot_token_performance(
|
|||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
local_window_size=1024 if provider == "ort_gqa_local" else -1,
|
||||
local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
|
||||
do_rotary=True, # Most models use rotary positional embeddings
|
||||
is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
|
@ -386,25 +166,46 @@ def run_performance_test(sm: int):
|
|||
Run performance tests for prompt and token generation.
|
||||
|
||||
"""
|
||||
for batch_size in [1, 4, 8, 16]:
|
||||
for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]:
|
||||
for head_size in [64, 128]:
|
||||
plot_prompt_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
max_seq_len=8192,
|
||||
head_size=head_size,
|
||||
)
|
||||
plot_token_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
max_seq_len=8192,
|
||||
head_size=head_size,
|
||||
)
|
||||
device_id = torch.cuda.current_device()
|
||||
memory_in_gb = torch.cuda.get_device_properties(device_id).total_memory / (1024 * 1024 * 1024)
|
||||
|
||||
# Note: some models use bf16.
|
||||
# We use fp16 for all models in this test since bf16 is not supported in ORT python API.
|
||||
configures = [
|
||||
(32, 128, 8, 8192, None, "Llama3-8B"),
|
||||
(64, 128, 8, 8192, None, "Llama3-70B"),
|
||||
(32, 128, 8, 32768, 4096, "Mistral-7B-v0.1"),
|
||||
(48, 128, 8, 65536, None, "Mixtral-8x22B-v0.1"),
|
||||
(32, 96, 32, 131072, None, "Phi-3-mini-128k"),
|
||||
(32, 128, 8, 131072, None, "Phi-3-small-128k"), # Sparsity is not used in this test
|
||||
(40, 128, 10, 131072, None, "Phi-3-medium-128K"),
|
||||
]
|
||||
|
||||
# Reduce max sequence length when GPU memory is not enough.
|
||||
threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768
|
||||
|
||||
for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures:
|
||||
for batch_size in [1, 4]:
|
||||
plot_prompt_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
max_seq_len=min(threshold, max_seq_len),
|
||||
local_window_size=local_window_size,
|
||||
model_name=model_name,
|
||||
)
|
||||
plot_token_performance(
|
||||
sm=sm,
|
||||
batch_size=batch_size,
|
||||
num_heads=num_heads,
|
||||
kv_num_heads=kv_num_heads,
|
||||
head_size=head_size,
|
||||
max_seq_len=min(threshold, max_seq_len),
|
||||
local_window_size=local_window_size,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue