[CUDA] Benchmark GQA on popular LLM models (#20646)

### Description
Update benchmark_gqa.py to test latency on popular models (like
Llama3-8b, Llama3-70b, Mixtral-8x22B-v0.1 and Phi-3 etc).

Note that this is latency of just one GroupQueryAttention node, not the
whole model. For example, packed QKV might need more time in GQA, but it
is faster in MatMul of input projection, the overall effect is not
measured here.

Example output in  A100-SXM4-80GB :
```
prompt-sm80-Llama3-8B-b1-h32_8x128-fp16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       0.019073                 0.016264
1             32.0       0.017768                 0.017957
2             64.0       0.023304                 0.023192
3            128.0       0.032541                 0.031348
4            256.0       0.048329                 0.049484
5            512.0       0.095294                 0.095950
6           1024.0       0.228050                 0.228980
7           2048.0       0.663820                 0.663308
8           4096.0       2.243657                 2.242999
9           8192.0       8.197120                 8.186282

token-sm80-Llama3-8B-b1-h32_8_d128-fp16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       0.018516                 0.015398
1                  32.0       0.015687                 0.016079
2                  64.0       0.016115                 0.016053
3                 128.0       0.018727                 0.019413
4                 256.0       0.036373                 0.035962
5                 512.0       0.041701                 0.042203
6                1024.0       0.053730                 0.053750
7                2048.0       0.076382                 0.075707
8                4096.0       0.121876                 0.121802
9                8191.0       0.211292                 0.211254

prompt-sm80-Llama3-8B-b4-h32_8x128-fp16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       0.024558                 0.022070
1             32.0       0.021276                 0.021406
2             64.0       0.044172                 0.027789
3            128.0       0.069100                 0.059071
4            256.0       0.146569                 0.106717
5            512.0       0.270472                 0.244461
6           1024.0       0.690024                 0.692501
7           2048.0       2.308546                 2.325453
8           4096.0       8.724295                 8.957337
9           8192.0      39.030785                41.381378

token-sm80-Llama3-8B-b4-h32_8_d128-fp16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       0.018893                 0.018611
1                  32.0       0.018124                 0.018190
2                  64.0       0.018115                 0.018156
3                 128.0       0.023291                 0.023733
4                 256.0       0.038357                 0.038351
5                 512.0       0.047117                 0.047792
6                1024.0       0.066272                 0.065409
7                2048.0       0.104196                 0.104527
8                4096.0       0.180557                 0.180424
9                8191.0       0.332545                 0.332714

prompt-sm80-Llama3-70B-b1-h64_8x128-fp16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       0.040974                 0.015852
1             32.0       0.017839                 0.018615
2             64.0       0.023956                 0.022704
3            128.0       0.044622                 0.035229
4            256.0       0.080241                 0.075237
5            512.0       0.143457                 0.144322
6           1024.0       0.380473                 0.381731
7           2048.0       1.217328                 1.214505
8           4096.0       4.305315                 4.286324
9           8192.0      15.918250                15.933440

token-sm80-Llama3-70B-b1-h64_8_d128-fp16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       0.016148                 0.015612
1                  32.0       0.015616                 0.015616
2                  64.0       0.016082                 0.016070
3                 128.0       0.019470                 0.019130
4                 256.0       0.036617                 0.037296
5                 512.0       0.042087                 0.042176
6                1024.0       0.053704                 0.053587
7                2048.0       0.076918                 0.076365
8                4096.0       0.122534                 0.121984
9                8191.0       0.212961                 0.213330

prompt-sm80-Llama3-70B-b4-h64_8x128-fp16:
   sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0             16.0       0.031137                 0.026270
1             32.0       0.030938                 0.032009
2             64.0       0.040833                 0.059118
3            128.0       0.084899                 0.085482
4            256.0       0.163951                 0.166310
5            512.0       0.420436                 0.423721
6           1024.0       1.282019                 1.283482
7           2048.0       4.397661                 4.420121
8           4096.0      16.931839                17.456945
9           8192.0      77.896706                83.007484

token-sm80-Llama3-70B-b4-h64_8_d128-fp16:
   past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                  16.0       0.026106                 0.026061
1                  32.0       0.025678                 0.025589
2                  64.0       0.025438                 0.025965
3                 128.0       0.033879                 0.033320
4                 256.0       0.058078                 0.057656
5                 512.0       0.078010                 0.078153
6                1024.0       0.106353                 0.098079
7                2048.0       0.160039                 0.159153
8                4096.0       0.282527                 0.283346
9                8191.0       0.546207                 0.542135

prompt-sm80-Mistral-7B-v0.1-b1-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
0              16.0       0.015722       0.015655                 0.015666                 0.016150
1              32.0       0.018590       0.018562                 0.018136                 0.024617
2              64.0       0.022480       0.023085                 0.023184                 0.023160
3             128.0       0.029948       0.030581                 0.030839                 0.031464
4             256.0       0.048532       0.049099                 0.049424                 0.049408
5             512.0       0.095096       0.095665                 0.096174                 0.096175
6            1024.0       0.228606       0.228942                 0.228434                 0.229568
7            2048.0       0.660832       0.661943                 0.662170                 0.663979
8            4096.0       2.238001       2.243999                 2.242243                 2.241707
9            8192.0       8.173824       6.147072                 8.187648                 6.152822
10          16384.0      33.826305      14.486015                34.849792                14.938283
11          32768.0     176.702469      32.725330               184.309753                34.736130

token-sm80-Mistral-7B-v0.1-b1-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
0                   16.0       0.015407       0.016042                 0.016030                 0.015429
1                   32.0       0.015525       0.016115                 0.016768                 0.016052
2                   64.0       0.015556       0.016079                 0.015383                 0.016008
3                  128.0       0.019302       0.018644                 0.018680                 0.019278
4                  256.0       0.036924       0.035900                 0.036753                 0.036786
5                  512.0       0.041482       0.041434                 0.041646                 0.042238
6                 1024.0       0.053587       0.052972                 0.052888                 0.052856
7                 2048.0       0.075749       0.075807                 0.076528                 0.075945
8                 4096.0       0.122053       0.122016                 0.122115                 0.122216
9                 8192.0       0.212069       0.121317                 0.211919                 0.121087
10               16384.0       0.394036       0.121202                 0.393661                 0.121483
11               32767.0       0.757216       0.124326                 0.757659                 0.124157

prompt-sm80-Mistral-7B-v0.1-b4-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
0              16.0       0.018418       0.018911                 0.023387                 0.019256
1              32.0       0.021085       0.021132                 0.022143                 0.022251
2              64.0       0.026743       0.026770                 0.027942                 0.027714
3             128.0       0.057922       0.058483                 0.058800                 0.059402
4             256.0       0.105927       0.104876                 0.106695                 0.105996
5             512.0       0.242958       0.242543                 0.244599                 0.244774
6            1024.0       0.689321       0.689347                 0.691759                 0.692334
7            2048.0       2.308250       2.304410                 2.321587                 2.317875
8            4096.0       8.705210       8.713682                 8.927418                 8.903866
9            8192.0      39.630848      28.227926                41.604607                29.648554
10          16384.0     175.553543      61.422592               183.384064                64.560127
11          32768.0     772.296692     132.006912               813.537292               138.996735

token-sm80-Mistral-7B-v0.1-b4-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Local  ORT-GQA-Dense-PackedQKV  ORT-GQA-Local-PackedQKV
0                   16.0       0.018127       0.018691                 0.018661                 0.018681
1                   32.0       0.018183       0.018812                 0.018739                 0.018759
2                   64.0       0.018081       0.018116                 0.018136                 0.018153
3                  128.0       0.023257       0.023146                 0.023114                 0.023103
4                  256.0       0.038665       0.038102                 0.038120                 0.038759
5                  512.0       0.047181       0.047156                 0.047012                 0.046382
6                 1024.0       0.066047       0.066103                 0.066604                 0.066076
7                 2048.0       0.104427       0.103770                 0.103799                 0.103807
8                 4096.0       0.180951       0.180373                 0.180173                 0.180154
9                 8192.0       0.334018       0.180801                 0.333269                 0.180690
10               16384.0       0.638682       0.180965                 0.638543                 0.180202
11               32767.0       1.249536       0.184779                 1.249963                 0.184624

prompt-sm80-Mixtral-8x22B-v0.1-b1-h48_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.015699                 0.015563
1              32.0       0.017931                 0.017719
2              64.0       0.029975                 0.022875
3             128.0       0.031038                 0.055747
4             256.0       0.050191                 0.050845
5             512.0       0.125187                 0.122813
6            1024.0       0.304004                 0.301824
7            2048.0       0.936454                 0.931546
8            4096.0       3.264547                 3.255931
9            8192.0      12.062719                12.030080
10          16384.0      49.018368                48.970749
11          32768.0     261.211151               254.461945
12          65536.0    1221.138428              1197.559814

token-sm80-Mixtral-8x22B-v0.1-b1-h48_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.015980                 0.016024
1                   32.0       0.015440                 0.016165
2                   64.0       0.015987                 0.015979
3                  128.0       0.020837                 0.018715
4                  256.0       0.036240                 0.036747
5                  512.0       0.042477                 0.041813
6                 1024.0       0.052950                 0.052956
7                 2048.0       0.076084                 0.076691
8                 4096.0       0.122233                 0.121540
9                 8192.0       0.212469                 0.212433
10               16384.0       0.394937                 0.394996
11               32768.0       0.757285                 0.757257
12               65535.0       1.484867                 1.485015

prompt-sm80-Mixtral-8x22B-v0.1-b4-h48_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.024119                 0.018755
1              32.0       0.022214                 0.022267
2              64.0       0.028045                 0.027562
3             128.0       0.062894                 0.079766
4             256.0       0.135146                 0.134483
5             512.0       0.331323                 0.329094
6            1024.0       0.984576                 0.982221
7            2048.0       3.353564                 3.351021
8            4096.0      12.762113                12.778350
9            8192.0      58.599422                57.704449
10          16384.0     263.392242               258.709503
11          32768.0    1155.789795              1128.622070
12          65536.0    5014.187012              4874.590332

token-sm80-Mixtral-8x22B-v0.1-b4-h48_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.018148                 0.018813
1                   32.0       0.018929                 0.018840
2                   64.0       0.018745                 0.018232
3                  128.0       0.023864                 0.023822
4                  256.0       0.038603                 0.038694
5                  512.0       0.048347                 0.047630
6                 1024.0       0.066957                 0.067392
7                 2048.0       0.105094                 0.105058
8                 4096.0       0.181941                 0.181808
9                 8192.0       0.334227                 0.334324
10               16384.0       0.640429                 0.640961
11               32768.0       1.267897                 1.269120
12               65535.0       2.534238                 2.504408

prompt-sm80-Phi-3-mini-128k-b1-h32_32x96-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.016112                 0.026949
1              32.0       0.016486                 0.017284
2              64.0       0.020910                 0.020994
3             128.0       0.029306                 0.029452
4             256.0       0.044604                 0.044642
5             512.0       0.090079                 0.086868
6            1024.0       0.208169                 0.208094
7            2048.0       0.604687                 0.607910
8            4096.0       2.029056                 2.046771
9            8192.0       7.792128                 7.906303
10          16384.0      34.271233                34.418175
11          32768.0     160.377853               159.980545
12          65536.0     733.443054               734.722046

token-sm80-Phi-3-mini-128k-b1-h32_32_d96-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.016339                 0.015718
1                   32.0       0.016572                 0.015964
2                   64.0       0.016182                 0.016192
3                  128.0       0.019373                 0.018621
4                  256.0       0.021856                 0.022463
5                  512.0       0.028943                 0.028888
6                 1024.0       0.041124                 0.041104
7                 2048.0       0.067668                 0.067542
8                 4096.0       0.117528                 0.117447
9                 8192.0       0.216241                 0.215492
10               16384.0       0.413434                 0.414047
11               32768.0       0.811085                 0.810612
12               65536.0       1.606189                 1.606458
13              131071.0       3.193037                 3.192491

prompt-sm80-Phi-3-mini-128k-b4-h32_32x96-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.019385                 0.019403
1              32.0       0.019801                 0.020006
2              64.0       0.025958                 0.025376
3             128.0       0.056445                 0.055909
4             256.0       0.103180                 0.102221
5             512.0       0.244224                 0.244360
6            1024.0       0.703066                 0.709327
7            2048.0       2.307456                 2.335001
8            4096.0       8.334522                 8.406760
9            8192.0      33.340416                33.758209
10          16384.0     144.141312               145.005569
11          32768.0     655.496216               655.656982
12          65536.0    2981.463135              2984.790039

token-sm80-Phi-3-mini-128k-b4-h32_32_d96-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.018701                 0.018185
1                   32.0       0.020625                 0.019213
2                   64.0       0.019936                 0.019943
3                  128.0       0.023648                 0.023689
4                  256.0       0.030309                 0.030305
5                  512.0       0.043501                 0.043801
6                 1024.0       0.067314                 0.068014
7                 2048.0       0.108649                 0.108134
8                 4096.0       0.186053                 0.186848
9                 8192.0       0.339973                 0.339742
10               16384.0       0.643288                 0.644366
11               32768.0       1.261468                 1.261510
12               65536.0       2.502252                 2.501820
13              131071.0       4.990437                 4.989521

prompt-sm80-Phi-3-small-128k-b1-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.025280                 0.023331
1              32.0       0.023071                 0.025931
2              64.0       0.022883                 0.026258
3             128.0       0.030658                 0.031445
4             256.0       0.057659                 0.057073
5             512.0       0.095589                 0.106579
6            1024.0       0.228532                 0.229402
7            2048.0       0.662315                 0.663349
8            4096.0       2.242885                 2.248095
9            8192.0       8.194646                 8.180395
10          16384.0      33.926659                35.130882
11          32768.0     175.320068               184.967163
12          65536.0     810.447876               847.632385

token-sm80-Phi-3-small-128k-b1-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.015517                 0.016038
1                   32.0       0.016372                 0.015477
2                   64.0       0.015472                 0.016016
3                  128.0       0.019291                 0.018664
4                  256.0       0.036250                 0.035990
5                  512.0       0.041691                 0.042238
6                 1024.0       0.053730                 0.053126
7                 2048.0       0.075912                 0.076439
8                 4096.0       0.121336                 0.121334
9                 8192.0       0.213104                 0.212443
10               16384.0       0.394353                 0.394272
11               32768.0       0.756965                 0.757017
12               65536.0       1.484548                 1.485371
13              131071.0       2.939200                 2.939552

prompt-sm80-Phi-3-small-128k-b4-h32_8x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.044326                 0.019298
1              32.0       0.021840                 0.021408
2              64.0       0.027492                 0.027802
3             128.0       0.058128                 0.059431
4             256.0       0.104300                 0.106019
5             512.0       0.242562                 0.244948
6            1024.0       0.689614                 0.692305
7            2048.0       2.297931                 2.312857
8            4096.0       8.654848                 8.843170
9            8192.0      38.770176                40.929279
10          16384.0     175.572998               183.692291
11          32768.0     780.126221               820.551697
12          65536.0    3357.564941              3488.527344

token-sm80-Phi-3-small-128k-b4-h32_8_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.018061                 0.017995
1                   32.0       0.018225                 0.018851
2                   64.0       0.018203                 0.018104
3                  128.0       0.023161                 0.023651
4                  256.0       0.038421                 0.037673
5                  512.0       0.047590                 0.046938
6                 1024.0       0.065639                 0.066055
7                 2048.0       0.103545                 0.103581
8                 4096.0       0.180461                 0.179998
9                 8192.0       0.332667                 0.332564
10               16384.0       0.638503                 0.639094
11               32768.0       1.249180                 1.249479
12               65536.0       2.469457                 2.471666
13              131071.0       4.915362                 4.914499

prompt-sm80-Phi-3-medium-128K-b1-h40_10x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.025759                 0.016318
1              32.0       0.018282                 0.018111
2              64.0       0.022642                 0.022978
3             128.0       0.030860                 0.037988
4             256.0       0.055703                 0.050318
5             512.0       0.113465                 0.113776
6            1024.0       0.267678                 0.268292
7            2048.0       0.795202                 0.797222
8            4096.0       2.737953                 2.740435
9            8192.0      10.101760                10.149092
10          16384.0      43.326466                43.990013
11          32768.0     230.886398               229.886978
12          65536.0    1067.412476              1052.922852

token-sm80-Phi-3-medium-128K-b1-h40_10_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.016122                 0.015582
1                   32.0       0.015594                 0.016262
2                   64.0       0.016099                 0.015512
3                  128.0       0.018708                 0.019510
4                  256.0       0.037582                 0.036341
5                  512.0       0.042411                 0.041894
6                 1024.0       0.053278                 0.053914
7                 2048.0       0.076553                 0.076636
8                 4096.0       0.121539                 0.121610
9                 8192.0       0.212083                 0.212377
10               16384.0       0.395086                 0.395280
11               32768.0       0.757879                 0.757888
12               65536.0       1.486093                 1.486915
13              131071.0       2.941728                 2.941408

prompt-sm80-Phi-3-medium-128K-b4-h40_10x128-fp16:
    sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0              16.0       0.019448                 0.018872
1              32.0       0.022290                 0.022380
2              64.0       0.027986                 0.027955
3             128.0       0.062699                 0.062175
4             256.0       0.124868                 0.125247
5             512.0       0.298873                 0.298169
6            1024.0       0.862584                 0.863467
7            2048.0       2.944640                 2.957824
8            4096.0      11.318656                11.390720
9            8192.0      52.606976                52.019199
10          16384.0     232.616959               230.360062
11          32768.0    1024.171997              1019.540466
12          65536.0    4377.362305              4354.510742

token-sm80-Phi-3-medium-128K-b4-h40_10_d128-fp16:
    past_sequence_length  ORT-GQA-Dense  ORT-GQA-Dense-PackedQKV
0                   16.0       0.018192                 0.018175
1                   32.0       0.018999                 0.018319
2                   64.0       0.018447                 0.018897
3                  128.0       0.023863                 0.023195
4                  256.0       0.037712                 0.038192
5                  512.0       0.048863                 0.048548
6                 1024.0       0.067244                 0.066473
7                 2048.0       0.105203                 0.105021
8                 4096.0       0.180712                 0.180429
9                 8192.0       0.334948                 0.334734
10               16384.0       0.640662                 0.639709
11               32768.0       1.252196                 1.251684
12               65536.0       2.474927                 2.474280
13              131071.0       4.930829                 4.959340
```
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Tianlei Wu 2024-05-10 14:14:15 -07:00 committed by GitHub
parent cfe830b248
commit 85facd678b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,302 +1,75 @@
import math
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""
Benchmark performance of GroupQueryAttention.
"""
from typing import Optional
import torch
from onnx import TensorProto, helper
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention
class AttentionConfig:
def __init__(
self,
operator: str,
batch_size: int,
sequence_length: int,
max_sequence_length: int,
past_sequence_length: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
softmax_scale: Optional[float],
do_rotary: bool,
rotary_interleaved: bool,
device="cuda",
dtype=torch.float16,
share_buffer: bool = True,
is_packed_qkv: bool = False,
):
self.operator = operator
self.batch_size = batch_size
self.sequence_length = sequence_length
self.max_sequence_length = max_sequence_length
self.past_sequence_length = past_sequence_length
self.num_heads = num_heads
self.kv_num_heads = kv_num_heads
self.head_size = head_size
self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5)
# Derived values
self.total_sequence_length = sequence_length + past_sequence_length
self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length
self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length)
self.do_rotary = do_rotary
self.rotary_interleaved = rotary_interleaved
self.device = device
self.share_buffer = share_buffer
self.is_packed_qkv = is_packed_qkv
self.dtype = dtype
def shape_dict(self):
return {
"query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size),
"past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size),
"total_sequence_length": (1,),
"output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size),
"present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size),
"cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
"sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2),
}
def get_cos_sin_cache(self, dtype):
rotary_fraction = 1.0
rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16
angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
return cos.to(device=self.device), sin.to(device=self.device)
def random_inputs(self):
device = self.device
# bfloat16 is not supported in ORT python I/O binding API
dtype = torch.float16
shape_dict = self.shape_dict()
torch.manual_seed(123)
feeds = {
"query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1),
"total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32),
}
if self.do_rotary:
cos_cache, sin_cache = self.get_cos_sin_cache(dtype)
feeds["cos_cache"] = cos_cache
feeds["sin_cache"] = sin_cache
return feeds
class GroupQueryAttentionConfig(AttentionConfig):
def __init__(
self,
batch_size: int,
sequence_length: int,
max_sequence_length: int,
past_sequence_length: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
softmax_scale=None,
do_rotary: bool = False,
rotary_interleaved: bool = False,
device="cuda",
local_window_size: int = -1,
):
super().__init__(
"GroupQueryAttention",
batch_size,
sequence_length,
max_sequence_length,
past_sequence_length,
num_heads,
kv_num_heads,
head_size,
softmax_scale,
do_rotary,
rotary_interleaved,
device,
)
self.local_window_size = local_window_size
def shape_dict(self):
shapes = super().shape_dict()
shapes.update(
{
"seqlens_k": (self.batch_size,),
}
)
return shapes
def random_inputs(self):
feeds = super().random_inputs()
k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length
feeds.update(
{
"seqlens_k": k_seqlens - 1,
}
)
return feeds
def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
assert config.dtype == torch.float16
float_type = TensorProto.FLOAT16
nodes = [
helper.make_node(
"GroupQueryAttention",
[
"query",
"key" if not config.is_packed_qkv else "",
"value" if not config.is_packed_qkv else "",
"past_key",
"past_value",
"seqlens_k",
"total_sequence_length" if config.share_buffer else "",
"cos_cache" if config.do_rotary else "",
"sin_cache" if config.do_rotary else "",
],
["output", "present_key", "present_value"],
"GroupQueryAttention_0",
num_heads=config.num_heads,
kv_num_heads=config.kv_num_heads,
scale=config.softmax_scale,
local_window_size=config.local_window_size,
do_rotary=1 if config.do_rotary else 0,
rotary_interleaved=config.rotary_interleaved,
domain="com.microsoft",
),
]
shape_dict = config.shape_dict()
graph_input = [
helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])),
helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])),
helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])),
helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])),
helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])),
helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])),
helper.make_tensor_value_info(
"total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"])
),
]
if config.do_rotary:
graph_input += [
helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])),
helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])),
]
graph_output = [
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])),
helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])),
helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])),
]
graph = helper.make_graph(
nodes,
"GroupQueryAttention_Graph",
graph_input,
graph_output,
)
model = helper.make_model(graph)
return model.SerializeToString()
def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession:
session_options = SessionOptions()
ort_session = InferenceSession(
onnx_model_str,
session_options,
providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"],
)
return ort_session
class OrtGroupQueryAttention:
"""A wrapper of ORT GroupQueryAttention to test relevance and performance."""
def __init__(self, config: GroupQueryAttentionConfig):
device = config.device
cuda_provider_options = CudaSession.get_cuda_provider_options(
torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream
)
onnx_model_str = create_group_query_attention_onnx_model(config)
self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options)
self.gpu_binding_manager = GpuBindingManager(
ort_session=self.ort_session,
device=device,
stream=torch.cuda.current_stream().cuda_stream,
max_cuda_graphs=2,
)
buffer_sharing = {"past_key": "present_key", "past_value": "present_value"}
self.gpu_binding = self.gpu_binding_manager.get_binding(
config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing
)
self.feed_dict = config.random_inputs()
def infer(self):
return self.gpu_binding.infer(self.feed_dict)
def get_plot_algos(sm: int):
def get_plot_algos(sm: int, local_window_size: Optional[int]):
# GQA with local windows only works in sm=8x
if sm >= 80:
if sm >= 80 and local_window_size:
return {
"line_vals": ["ort_gqa", "ort_gqa_local"],
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"],
"styles": [("red", "-"), ("blue", "-")],
"line_vals": ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"],
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Local", "ORT-GQA-Dense-PackedQKV", "ORT-GQA-Local-PackedQKV"],
"styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")],
}
else:
return {
"line_vals": ["ort_gqa"],
"line_names": ["ORT-GQA-Dense"],
"styles": [("green", "-")],
"line_vals": ["ort_gqa", "ort_gqa_packed"],
"line_names": ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"],
"styles": [("red", "solid"), ("blue", "dashed")],
}
def plot_prompt_performance(
sm: int,
batch_size=4,
num_heads=32,
kv_num_heads=8,
max_seq_len=8192,
head_size=128,
model_name: str,
batch_size: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
max_seq_len: int,
local_window_size: Optional[int] = None,
):
import triton
algos = get_plot_algos(sm)
algos = get_plot_algos(sm, local_window_size)
configs = [
triton.testing.Benchmark(
x_names=["sequence_length"],
x_vals=[2**i for i in range(4, 14)],
x_vals=[2**i for i in range(4, 17) if 2**i <= max_seq_len],
line_arg="provider",
ylabel="ms",
**algos,
plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}x{head_size}-fp16",
args={
"batch_size": batch_size,
"num_heads": num_heads,
"kv_num_heads": kv_num_heads,
"batch_size": batch_size,
"head_size": head_size,
"local_window_size": local_window_size,
},
)
]
@triton.testing.perf_report(configs)
def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"):
def benchmark(
provider: str,
sequence_length: int,
batch_size: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
local_window_size: Optional[int] = None,
device="cuda",
):
warmup = 15
repeat = 100
@ -308,8 +81,9 @@ def plot_prompt_performance(
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
local_window_size=1024 if provider == "ort_gqa_local" else -1,
local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
device=device,
is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
)
obj = OrtGroupQueryAttention(config)
@ -322,40 +96,44 @@ def plot_prompt_performance(
def plot_token_performance(
sm: int,
batch_size=4,
num_heads=32,
kv_num_heads=8,
max_seq_len=8192,
head_size=128,
model_name: str,
batch_size: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
max_seq_len: int,
local_window_size: Optional[int] = None,
):
import triton
algos = get_plot_algos(sm)
algos = get_plot_algos(sm, local_window_size)
configs = [
triton.testing.Benchmark(
x_names=["past_sequence_length"],
x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1],
x_vals=[2**i for i in range(4, 17) if 2**i < max_seq_len] + [max_seq_len - 1],
line_arg="provider",
ylabel="ms",
**algos,
plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16",
plot_name=f"token-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}_d{head_size}-fp16",
args={
"batch_size": batch_size,
"num_heads": num_heads,
"kv_num_heads": kv_num_heads,
"batch_size": batch_size,
"head_size": head_size,
"local_window_size": local_window_size,
},
)
]
@triton.testing.perf_report(configs)
def benchmark(
batch_size,
num_heads,
kv_num_heads,
past_sequence_length,
head_size,
provider,
provider: str,
past_sequence_length: int,
batch_size: int,
num_heads: int,
kv_num_heads: int,
head_size: int,
local_window_size: Optional[int] = None,
device="cuda",
):
warmup = 15
@ -369,7 +147,9 @@ def plot_token_performance(
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
local_window_size=1024 if provider == "ort_gqa_local" else -1,
local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
do_rotary=True, # Most models use rotary positional embeddings
is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
device=device,
)
@ -386,25 +166,46 @@ def run_performance_test(sm: int):
Run performance tests for prompt and token generation.
"""
for batch_size in [1, 4, 8, 16]:
for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]:
for head_size in [64, 128]:
plot_prompt_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
max_seq_len=8192,
head_size=head_size,
)
plot_token_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
max_seq_len=8192,
head_size=head_size,
)
device_id = torch.cuda.current_device()
memory_in_gb = torch.cuda.get_device_properties(device_id).total_memory / (1024 * 1024 * 1024)
# Note: some models use bf16.
# We use fp16 for all models in this test since bf16 is not supported in ORT python API.
configures = [
(32, 128, 8, 8192, None, "Llama3-8B"),
(64, 128, 8, 8192, None, "Llama3-70B"),
(32, 128, 8, 32768, 4096, "Mistral-7B-v0.1"),
(48, 128, 8, 65536, None, "Mixtral-8x22B-v0.1"),
(32, 96, 32, 131072, None, "Phi-3-mini-128k"),
(32, 128, 8, 131072, None, "Phi-3-small-128k"), # Sparsity is not used in this test
(40, 128, 10, 131072, None, "Phi-3-medium-128K"),
]
# Reduce max sequence length when GPU memory is not enough.
threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768
for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures:
for batch_size in [1, 4]:
plot_prompt_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
max_seq_len=min(threshold, max_seq_len),
local_window_size=local_window_size,
model_name=model_name,
)
plot_token_performance(
sm=sm,
batch_size=batch_size,
num_heads=num_heads,
kv_num_heads=kv_num_heads,
head_size=head_size,
max_seq_len=min(threshold, max_seq_len),
local_window_size=local_window_size,
model_name=model_name,
)
if __name__ == "__main__":