From 85facd678bbd3835cc8acbaffbaa78a0c9224200 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 10 May 2024 14:14:15 -0700 Subject: [PATCH] [CUDA] Benchmark GQA on popular LLM models (#20646) ### Description Update benchmark_gqa.py to test latency on popular models (like Llama3-8b, Llama3-70b, Mixtral-8x22B-v0.1 and Phi-3 etc). Note that this is latency of just one GroupQueryAttention node, not the whole model. For example, packed QKV might need more time in GQA, but it is faster in MatMul of input projection, the overall effect is not measured here. Example output in A100-SXM4-80GB : ``` prompt-sm80-Llama3-8B-b1-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.019073 0.016264 1 32.0 0.017768 0.017957 2 64.0 0.023304 0.023192 3 128.0 0.032541 0.031348 4 256.0 0.048329 0.049484 5 512.0 0.095294 0.095950 6 1024.0 0.228050 0.228980 7 2048.0 0.663820 0.663308 8 4096.0 2.243657 2.242999 9 8192.0 8.197120 8.186282 token-sm80-Llama3-8B-b1-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018516 0.015398 1 32.0 0.015687 0.016079 2 64.0 0.016115 0.016053 3 128.0 0.018727 0.019413 4 256.0 0.036373 0.035962 5 512.0 0.041701 0.042203 6 1024.0 0.053730 0.053750 7 2048.0 0.076382 0.075707 8 4096.0 0.121876 0.121802 9 8191.0 0.211292 0.211254 prompt-sm80-Llama3-8B-b4-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.024558 0.022070 1 32.0 0.021276 0.021406 2 64.0 0.044172 0.027789 3 128.0 0.069100 0.059071 4 256.0 0.146569 0.106717 5 512.0 0.270472 0.244461 6 1024.0 0.690024 0.692501 7 2048.0 2.308546 2.325453 8 4096.0 8.724295 8.957337 9 8192.0 39.030785 41.381378 token-sm80-Llama3-8B-b4-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018893 0.018611 1 32.0 0.018124 0.018190 2 64.0 0.018115 0.018156 3 128.0 0.023291 0.023733 4 256.0 0.038357 0.038351 5 512.0 0.047117 0.047792 6 1024.0 0.066272 0.065409 7 2048.0 0.104196 0.104527 8 4096.0 0.180557 0.180424 9 8191.0 0.332545 0.332714 prompt-sm80-Llama3-70B-b1-h64_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.040974 0.015852 1 32.0 0.017839 0.018615 2 64.0 0.023956 0.022704 3 128.0 0.044622 0.035229 4 256.0 0.080241 0.075237 5 512.0 0.143457 0.144322 6 1024.0 0.380473 0.381731 7 2048.0 1.217328 1.214505 8 4096.0 4.305315 4.286324 9 8192.0 15.918250 15.933440 token-sm80-Llama3-70B-b1-h64_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.016148 0.015612 1 32.0 0.015616 0.015616 2 64.0 0.016082 0.016070 3 128.0 0.019470 0.019130 4 256.0 0.036617 0.037296 5 512.0 0.042087 0.042176 6 1024.0 0.053704 0.053587 7 2048.0 0.076918 0.076365 8 4096.0 0.122534 0.121984 9 8191.0 0.212961 0.213330 prompt-sm80-Llama3-70B-b4-h64_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.031137 0.026270 1 32.0 0.030938 0.032009 2 64.0 0.040833 0.059118 3 128.0 0.084899 0.085482 4 256.0 0.163951 0.166310 5 512.0 0.420436 0.423721 6 1024.0 1.282019 1.283482 7 2048.0 4.397661 4.420121 8 4096.0 16.931839 17.456945 9 8192.0 77.896706 83.007484 token-sm80-Llama3-70B-b4-h64_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.026106 0.026061 1 32.0 0.025678 0.025589 2 64.0 0.025438 0.025965 3 128.0 0.033879 0.033320 4 256.0 0.058078 0.057656 5 512.0 0.078010 0.078153 6 1024.0 0.106353 0.098079 7 2048.0 0.160039 0.159153 8 4096.0 0.282527 0.283346 9 8191.0 0.546207 0.542135 prompt-sm80-Mistral-7B-v0.1-b1-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV 0 16.0 0.015722 0.015655 0.015666 0.016150 1 32.0 0.018590 0.018562 0.018136 0.024617 2 64.0 0.022480 0.023085 0.023184 0.023160 3 128.0 0.029948 0.030581 0.030839 0.031464 4 256.0 0.048532 0.049099 0.049424 0.049408 5 512.0 0.095096 0.095665 0.096174 0.096175 6 1024.0 0.228606 0.228942 0.228434 0.229568 7 2048.0 0.660832 0.661943 0.662170 0.663979 8 4096.0 2.238001 2.243999 2.242243 2.241707 9 8192.0 8.173824 6.147072 8.187648 6.152822 10 16384.0 33.826305 14.486015 34.849792 14.938283 11 32768.0 176.702469 32.725330 184.309753 34.736130 token-sm80-Mistral-7B-v0.1-b1-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV 0 16.0 0.015407 0.016042 0.016030 0.015429 1 32.0 0.015525 0.016115 0.016768 0.016052 2 64.0 0.015556 0.016079 0.015383 0.016008 3 128.0 0.019302 0.018644 0.018680 0.019278 4 256.0 0.036924 0.035900 0.036753 0.036786 5 512.0 0.041482 0.041434 0.041646 0.042238 6 1024.0 0.053587 0.052972 0.052888 0.052856 7 2048.0 0.075749 0.075807 0.076528 0.075945 8 4096.0 0.122053 0.122016 0.122115 0.122216 9 8192.0 0.212069 0.121317 0.211919 0.121087 10 16384.0 0.394036 0.121202 0.393661 0.121483 11 32767.0 0.757216 0.124326 0.757659 0.124157 prompt-sm80-Mistral-7B-v0.1-b4-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV 0 16.0 0.018418 0.018911 0.023387 0.019256 1 32.0 0.021085 0.021132 0.022143 0.022251 2 64.0 0.026743 0.026770 0.027942 0.027714 3 128.0 0.057922 0.058483 0.058800 0.059402 4 256.0 0.105927 0.104876 0.106695 0.105996 5 512.0 0.242958 0.242543 0.244599 0.244774 6 1024.0 0.689321 0.689347 0.691759 0.692334 7 2048.0 2.308250 2.304410 2.321587 2.317875 8 4096.0 8.705210 8.713682 8.927418 8.903866 9 8192.0 39.630848 28.227926 41.604607 29.648554 10 16384.0 175.553543 61.422592 183.384064 64.560127 11 32768.0 772.296692 132.006912 813.537292 138.996735 token-sm80-Mistral-7B-v0.1-b4-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Local ORT-GQA-Dense-PackedQKV ORT-GQA-Local-PackedQKV 0 16.0 0.018127 0.018691 0.018661 0.018681 1 32.0 0.018183 0.018812 0.018739 0.018759 2 64.0 0.018081 0.018116 0.018136 0.018153 3 128.0 0.023257 0.023146 0.023114 0.023103 4 256.0 0.038665 0.038102 0.038120 0.038759 5 512.0 0.047181 0.047156 0.047012 0.046382 6 1024.0 0.066047 0.066103 0.066604 0.066076 7 2048.0 0.104427 0.103770 0.103799 0.103807 8 4096.0 0.180951 0.180373 0.180173 0.180154 9 8192.0 0.334018 0.180801 0.333269 0.180690 10 16384.0 0.638682 0.180965 0.638543 0.180202 11 32767.0 1.249536 0.184779 1.249963 0.184624 prompt-sm80-Mixtral-8x22B-v0.1-b1-h48_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.015699 0.015563 1 32.0 0.017931 0.017719 2 64.0 0.029975 0.022875 3 128.0 0.031038 0.055747 4 256.0 0.050191 0.050845 5 512.0 0.125187 0.122813 6 1024.0 0.304004 0.301824 7 2048.0 0.936454 0.931546 8 4096.0 3.264547 3.255931 9 8192.0 12.062719 12.030080 10 16384.0 49.018368 48.970749 11 32768.0 261.211151 254.461945 12 65536.0 1221.138428 1197.559814 token-sm80-Mixtral-8x22B-v0.1-b1-h48_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.015980 0.016024 1 32.0 0.015440 0.016165 2 64.0 0.015987 0.015979 3 128.0 0.020837 0.018715 4 256.0 0.036240 0.036747 5 512.0 0.042477 0.041813 6 1024.0 0.052950 0.052956 7 2048.0 0.076084 0.076691 8 4096.0 0.122233 0.121540 9 8192.0 0.212469 0.212433 10 16384.0 0.394937 0.394996 11 32768.0 0.757285 0.757257 12 65535.0 1.484867 1.485015 prompt-sm80-Mixtral-8x22B-v0.1-b4-h48_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.024119 0.018755 1 32.0 0.022214 0.022267 2 64.0 0.028045 0.027562 3 128.0 0.062894 0.079766 4 256.0 0.135146 0.134483 5 512.0 0.331323 0.329094 6 1024.0 0.984576 0.982221 7 2048.0 3.353564 3.351021 8 4096.0 12.762113 12.778350 9 8192.0 58.599422 57.704449 10 16384.0 263.392242 258.709503 11 32768.0 1155.789795 1128.622070 12 65536.0 5014.187012 4874.590332 token-sm80-Mixtral-8x22B-v0.1-b4-h48_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018148 0.018813 1 32.0 0.018929 0.018840 2 64.0 0.018745 0.018232 3 128.0 0.023864 0.023822 4 256.0 0.038603 0.038694 5 512.0 0.048347 0.047630 6 1024.0 0.066957 0.067392 7 2048.0 0.105094 0.105058 8 4096.0 0.181941 0.181808 9 8192.0 0.334227 0.334324 10 16384.0 0.640429 0.640961 11 32768.0 1.267897 1.269120 12 65535.0 2.534238 2.504408 prompt-sm80-Phi-3-mini-128k-b1-h32_32x96-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.016112 0.026949 1 32.0 0.016486 0.017284 2 64.0 0.020910 0.020994 3 128.0 0.029306 0.029452 4 256.0 0.044604 0.044642 5 512.0 0.090079 0.086868 6 1024.0 0.208169 0.208094 7 2048.0 0.604687 0.607910 8 4096.0 2.029056 2.046771 9 8192.0 7.792128 7.906303 10 16384.0 34.271233 34.418175 11 32768.0 160.377853 159.980545 12 65536.0 733.443054 734.722046 token-sm80-Phi-3-mini-128k-b1-h32_32_d96-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.016339 0.015718 1 32.0 0.016572 0.015964 2 64.0 0.016182 0.016192 3 128.0 0.019373 0.018621 4 256.0 0.021856 0.022463 5 512.0 0.028943 0.028888 6 1024.0 0.041124 0.041104 7 2048.0 0.067668 0.067542 8 4096.0 0.117528 0.117447 9 8192.0 0.216241 0.215492 10 16384.0 0.413434 0.414047 11 32768.0 0.811085 0.810612 12 65536.0 1.606189 1.606458 13 131071.0 3.193037 3.192491 prompt-sm80-Phi-3-mini-128k-b4-h32_32x96-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.019385 0.019403 1 32.0 0.019801 0.020006 2 64.0 0.025958 0.025376 3 128.0 0.056445 0.055909 4 256.0 0.103180 0.102221 5 512.0 0.244224 0.244360 6 1024.0 0.703066 0.709327 7 2048.0 2.307456 2.335001 8 4096.0 8.334522 8.406760 9 8192.0 33.340416 33.758209 10 16384.0 144.141312 145.005569 11 32768.0 655.496216 655.656982 12 65536.0 2981.463135 2984.790039 token-sm80-Phi-3-mini-128k-b4-h32_32_d96-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018701 0.018185 1 32.0 0.020625 0.019213 2 64.0 0.019936 0.019943 3 128.0 0.023648 0.023689 4 256.0 0.030309 0.030305 5 512.0 0.043501 0.043801 6 1024.0 0.067314 0.068014 7 2048.0 0.108649 0.108134 8 4096.0 0.186053 0.186848 9 8192.0 0.339973 0.339742 10 16384.0 0.643288 0.644366 11 32768.0 1.261468 1.261510 12 65536.0 2.502252 2.501820 13 131071.0 4.990437 4.989521 prompt-sm80-Phi-3-small-128k-b1-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.025280 0.023331 1 32.0 0.023071 0.025931 2 64.0 0.022883 0.026258 3 128.0 0.030658 0.031445 4 256.0 0.057659 0.057073 5 512.0 0.095589 0.106579 6 1024.0 0.228532 0.229402 7 2048.0 0.662315 0.663349 8 4096.0 2.242885 2.248095 9 8192.0 8.194646 8.180395 10 16384.0 33.926659 35.130882 11 32768.0 175.320068 184.967163 12 65536.0 810.447876 847.632385 token-sm80-Phi-3-small-128k-b1-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.015517 0.016038 1 32.0 0.016372 0.015477 2 64.0 0.015472 0.016016 3 128.0 0.019291 0.018664 4 256.0 0.036250 0.035990 5 512.0 0.041691 0.042238 6 1024.0 0.053730 0.053126 7 2048.0 0.075912 0.076439 8 4096.0 0.121336 0.121334 9 8192.0 0.213104 0.212443 10 16384.0 0.394353 0.394272 11 32768.0 0.756965 0.757017 12 65536.0 1.484548 1.485371 13 131071.0 2.939200 2.939552 prompt-sm80-Phi-3-small-128k-b4-h32_8x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.044326 0.019298 1 32.0 0.021840 0.021408 2 64.0 0.027492 0.027802 3 128.0 0.058128 0.059431 4 256.0 0.104300 0.106019 5 512.0 0.242562 0.244948 6 1024.0 0.689614 0.692305 7 2048.0 2.297931 2.312857 8 4096.0 8.654848 8.843170 9 8192.0 38.770176 40.929279 10 16384.0 175.572998 183.692291 11 32768.0 780.126221 820.551697 12 65536.0 3357.564941 3488.527344 token-sm80-Phi-3-small-128k-b4-h32_8_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018061 0.017995 1 32.0 0.018225 0.018851 2 64.0 0.018203 0.018104 3 128.0 0.023161 0.023651 4 256.0 0.038421 0.037673 5 512.0 0.047590 0.046938 6 1024.0 0.065639 0.066055 7 2048.0 0.103545 0.103581 8 4096.0 0.180461 0.179998 9 8192.0 0.332667 0.332564 10 16384.0 0.638503 0.639094 11 32768.0 1.249180 1.249479 12 65536.0 2.469457 2.471666 13 131071.0 4.915362 4.914499 prompt-sm80-Phi-3-medium-128K-b1-h40_10x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.025759 0.016318 1 32.0 0.018282 0.018111 2 64.0 0.022642 0.022978 3 128.0 0.030860 0.037988 4 256.0 0.055703 0.050318 5 512.0 0.113465 0.113776 6 1024.0 0.267678 0.268292 7 2048.0 0.795202 0.797222 8 4096.0 2.737953 2.740435 9 8192.0 10.101760 10.149092 10 16384.0 43.326466 43.990013 11 32768.0 230.886398 229.886978 12 65536.0 1067.412476 1052.922852 token-sm80-Phi-3-medium-128K-b1-h40_10_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.016122 0.015582 1 32.0 0.015594 0.016262 2 64.0 0.016099 0.015512 3 128.0 0.018708 0.019510 4 256.0 0.037582 0.036341 5 512.0 0.042411 0.041894 6 1024.0 0.053278 0.053914 7 2048.0 0.076553 0.076636 8 4096.0 0.121539 0.121610 9 8192.0 0.212083 0.212377 10 16384.0 0.395086 0.395280 11 32768.0 0.757879 0.757888 12 65536.0 1.486093 1.486915 13 131071.0 2.941728 2.941408 prompt-sm80-Phi-3-medium-128K-b4-h40_10x128-fp16: sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.019448 0.018872 1 32.0 0.022290 0.022380 2 64.0 0.027986 0.027955 3 128.0 0.062699 0.062175 4 256.0 0.124868 0.125247 5 512.0 0.298873 0.298169 6 1024.0 0.862584 0.863467 7 2048.0 2.944640 2.957824 8 4096.0 11.318656 11.390720 9 8192.0 52.606976 52.019199 10 16384.0 232.616959 230.360062 11 32768.0 1024.171997 1019.540466 12 65536.0 4377.362305 4354.510742 token-sm80-Phi-3-medium-128K-b4-h40_10_d128-fp16: past_sequence_length ORT-GQA-Dense ORT-GQA-Dense-PackedQKV 0 16.0 0.018192 0.018175 1 32.0 0.018999 0.018319 2 64.0 0.018447 0.018897 3 128.0 0.023863 0.023195 4 256.0 0.037712 0.038192 5 512.0 0.048863 0.048548 6 1024.0 0.067244 0.066473 7 2048.0 0.105203 0.105021 8 4096.0 0.180712 0.180429 9 8192.0 0.334948 0.334734 10 16384.0 0.640662 0.639709 11 32768.0 1.252196 1.251684 12 65536.0 2.474927 2.474280 13 131071.0 4.930829 4.959340 ``` ### Motivation and Context --- .../test/python/transformers/benchmark_gqa.py | 405 +++++------------- 1 file changed, 103 insertions(+), 302 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 7fcd56bb8f..5e028519b9 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -1,302 +1,75 @@ -import math +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark performance of GroupQueryAttention. +""" from typing import Optional import torch -from onnx import TensorProto, helper - -from onnxruntime import InferenceSession, SessionOptions -from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager +from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention -class AttentionConfig: - def __init__( - self, - operator: str, - batch_size: int, - sequence_length: int, - max_sequence_length: int, - past_sequence_length: int, - num_heads: int, - kv_num_heads: int, - head_size: int, - softmax_scale: Optional[float], - do_rotary: bool, - rotary_interleaved: bool, - device="cuda", - dtype=torch.float16, - share_buffer: bool = True, - is_packed_qkv: bool = False, - ): - self.operator = operator - self.batch_size = batch_size - self.sequence_length = sequence_length - self.max_sequence_length = max_sequence_length - self.past_sequence_length = past_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - self.softmax_scale = softmax_scale if softmax_scale is not None else 1.0 / (head_size**0.5) - - # Derived values - self.total_sequence_length = sequence_length + past_sequence_length - self.past_buffer_length = max_sequence_length if share_buffer else past_sequence_length - self.present_buffer_length = max_sequence_length if share_buffer else (past_sequence_length + sequence_length) - - self.do_rotary = do_rotary - self.rotary_interleaved = rotary_interleaved - self.device = device - - self.share_buffer = share_buffer - self.is_packed_qkv = is_packed_qkv - self.dtype = dtype - - def shape_dict(self): - return { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), - "past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), - "past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), - "total_sequence_length": (1,), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), - "present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), - "cos_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), - "sin_cache": (self.max_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), - } - - def get_cos_sin_cache(self, dtype): - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16 - angle = torch.rand(self.max_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - return cos.to(device=self.device), sin.to(device=self.device) - - def random_inputs(self): - device = self.device - # bfloat16 is not supported in ORT python I/O binding API - dtype = torch.float16 - shape_dict = self.shape_dict() - - torch.manual_seed(123) - feeds = { - "query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32), - } - - if self.do_rotary: - cos_cache, sin_cache = self.get_cos_sin_cache(dtype) - feeds["cos_cache"] = cos_cache - feeds["sin_cache"] = sin_cache - - return feeds - - -class GroupQueryAttentionConfig(AttentionConfig): - def __init__( - self, - batch_size: int, - sequence_length: int, - max_sequence_length: int, - past_sequence_length: int, - num_heads: int, - kv_num_heads: int, - head_size: int, - softmax_scale=None, - do_rotary: bool = False, - rotary_interleaved: bool = False, - device="cuda", - local_window_size: int = -1, - ): - super().__init__( - "GroupQueryAttention", - batch_size, - sequence_length, - max_sequence_length, - past_sequence_length, - num_heads, - kv_num_heads, - head_size, - softmax_scale, - do_rotary, - rotary_interleaved, - device, - ) - self.local_window_size = local_window_size - - def shape_dict(self): - shapes = super().shape_dict() - shapes.update( - { - "seqlens_k": (self.batch_size,), - } - ) - return shapes - - def random_inputs(self): - feeds = super().random_inputs() - k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length - feeds.update( - { - "seqlens_k": k_seqlens - 1, - } - ) - return feeds - - -def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): - assert config.dtype == torch.float16 - - float_type = TensorProto.FLOAT16 - nodes = [ - helper.make_node( - "GroupQueryAttention", - [ - "query", - "key" if not config.is_packed_qkv else "", - "value" if not config.is_packed_qkv else "", - "past_key", - "past_value", - "seqlens_k", - "total_sequence_length" if config.share_buffer else "", - "cos_cache" if config.do_rotary else "", - "sin_cache" if config.do_rotary else "", - ], - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - num_heads=config.num_heads, - kv_num_heads=config.kv_num_heads, - scale=config.softmax_scale, - local_window_size=config.local_window_size, - do_rotary=1 if config.do_rotary else 0, - rotary_interleaved=config.rotary_interleaved, - domain="com.microsoft", - ), - ] - - shape_dict = config.shape_dict() - graph_input = [ - helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])), - helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])), - helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])), - helper.make_tensor_value_info("past_key", float_type, list(shape_dict["past_key"])), - helper.make_tensor_value_info("past_value", float_type, list(shape_dict["past_value"])), - helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])), - helper.make_tensor_value_info( - "total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"]) - ), - ] - - if config.do_rotary: - graph_input += [ - helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])), - helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), - ] - - graph_output = [ - helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])), - helper.make_tensor_value_info("present_key", float_type, list(shape_dict["present_key"])), - helper.make_tensor_value_info("present_value", float_type, list(shape_dict["present_value"])), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession: - session_options = SessionOptions() - ort_session = InferenceSession( - onnx_model_str, - session_options, - providers=[("CUDAExecutionProvider", cuda_provider_options), "CPUExecutionProvider"], - ) - return ort_session - - -class OrtGroupQueryAttention: - """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" - - def __init__(self, config: GroupQueryAttentionConfig): - device = config.device - cuda_provider_options = CudaSession.get_cuda_provider_options( - torch.cuda.current_device(), enable_cuda_graph=False, stream=torch.cuda.current_stream().cuda_stream - ) - onnx_model_str = create_group_query_attention_onnx_model(config) - self.ort_session = create_session(onnx_model_str, cuda_provider_options=cuda_provider_options) - self.gpu_binding_manager = GpuBindingManager( - ort_session=self.ort_session, - device=device, - stream=torch.cuda.current_stream().cuda_stream, - max_cuda_graphs=2, - ) - buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} - self.gpu_binding = self.gpu_binding_manager.get_binding( - config.shape_dict(), use_cuda_graph=False, buffer_sharing=buffer_sharing - ) - self.feed_dict = config.random_inputs() - - def infer(self): - return self.gpu_binding.infer(self.feed_dict) - - -def get_plot_algos(sm: int): +def get_plot_algos(sm: int, local_window_size: Optional[int]): # GQA with local windows only works in sm=8x - if sm >= 80: + if sm >= 80 and local_window_size: return { - "line_vals": ["ort_gqa", "ort_gqa_local"], - "line_names": ["ORT-GQA-Dense", "ORT-GQA-Local"], - "styles": [("red", "-"), ("blue", "-")], + "line_vals": ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"], + "line_names": ["ORT-GQA-Dense", "ORT-GQA-Local", "ORT-GQA-Dense-PackedQKV", "ORT-GQA-Local-PackedQKV"], + "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")], } else: return { - "line_vals": ["ort_gqa"], - "line_names": ["ORT-GQA-Dense"], - "styles": [("green", "-")], + "line_vals": ["ort_gqa", "ort_gqa_packed"], + "line_names": ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"], + "styles": [("red", "solid"), ("blue", "dashed")], } def plot_prompt_performance( sm: int, - batch_size=4, - num_heads=32, - kv_num_heads=8, - max_seq_len=8192, - head_size=128, + model_name: str, + batch_size: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + max_seq_len: int, + local_window_size: Optional[int] = None, ): import triton - algos = get_plot_algos(sm) + algos = get_plot_algos(sm, local_window_size) configs = [ triton.testing.Benchmark( x_names=["sequence_length"], - x_vals=[2**i for i in range(4, 14)], + x_vals=[2**i for i in range(4, 17) if 2**i <= max_seq_len], line_arg="provider", ylabel="ms", **algos, - plot_name=f"prompt-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16", + plot_name=f"prompt-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}x{head_size}-fp16", args={ + "batch_size": batch_size, "num_heads": num_heads, "kv_num_heads": kv_num_heads, - "batch_size": batch_size, "head_size": head_size, + "local_window_size": local_window_size, }, ) ] @triton.testing.perf_report(configs) - def benchmark(batch_size, num_heads, kv_num_heads, sequence_length, head_size, provider, device="cuda"): + def benchmark( + provider: str, + sequence_length: int, + batch_size: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + local_window_size: Optional[int] = None, + device="cuda", + ): warmup = 15 repeat = 100 @@ -308,8 +81,9 @@ def plot_prompt_performance( num_heads=num_heads, kv_num_heads=kv_num_heads, head_size=head_size, - local_window_size=1024 if provider == "ort_gqa_local" else -1, + local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, device=device, + is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], ) obj = OrtGroupQueryAttention(config) @@ -322,40 +96,44 @@ def plot_prompt_performance( def plot_token_performance( sm: int, - batch_size=4, - num_heads=32, - kv_num_heads=8, - max_seq_len=8192, - head_size=128, + model_name: str, + batch_size: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + max_seq_len: int, + local_window_size: Optional[int] = None, ): import triton - algos = get_plot_algos(sm) + algos = get_plot_algos(sm, local_window_size) configs = [ triton.testing.Benchmark( x_names=["past_sequence_length"], - x_vals=[2**i for i in range(4, 13)] + [max_seq_len - 1], + x_vals=[2**i for i in range(4, 17) if 2**i < max_seq_len] + [max_seq_len - 1], line_arg="provider", ylabel="ms", **algos, - plot_name=f"token-sm{sm}-batch{batch_size}-head{num_heads}_kv{kv_num_heads}-d{head_size}-fp16", + plot_name=f"token-sm{sm}-{model_name}-b{batch_size}-h{num_heads}_{kv_num_heads}_d{head_size}-fp16", args={ + "batch_size": batch_size, "num_heads": num_heads, "kv_num_heads": kv_num_heads, - "batch_size": batch_size, "head_size": head_size, + "local_window_size": local_window_size, }, ) ] @triton.testing.perf_report(configs) def benchmark( - batch_size, - num_heads, - kv_num_heads, - past_sequence_length, - head_size, - provider, + provider: str, + past_sequence_length: int, + batch_size: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + local_window_size: Optional[int] = None, device="cuda", ): warmup = 15 @@ -369,7 +147,9 @@ def plot_token_performance( num_heads=num_heads, kv_num_heads=kv_num_heads, head_size=head_size, - local_window_size=1024 if provider == "ort_gqa_local" else -1, + local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, + do_rotary=True, # Most models use rotary positional embeddings + is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], device=device, ) @@ -386,25 +166,46 @@ def run_performance_test(sm: int): Run performance tests for prompt and token generation. """ - for batch_size in [1, 4, 8, 16]: - for num_heads, kv_num_heads in [(8, 8), (16, 8), (32, 8), (64, 8)]: - for head_size in [64, 128]: - plot_prompt_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - max_seq_len=8192, - head_size=head_size, - ) - plot_token_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - max_seq_len=8192, - head_size=head_size, - ) + device_id = torch.cuda.current_device() + memory_in_gb = torch.cuda.get_device_properties(device_id).total_memory / (1024 * 1024 * 1024) + + # Note: some models use bf16. + # We use fp16 for all models in this test since bf16 is not supported in ORT python API. + configures = [ + (32, 128, 8, 8192, None, "Llama3-8B"), + (64, 128, 8, 8192, None, "Llama3-70B"), + (32, 128, 8, 32768, 4096, "Mistral-7B-v0.1"), + (48, 128, 8, 65536, None, "Mixtral-8x22B-v0.1"), + (32, 96, 32, 131072, None, "Phi-3-mini-128k"), + (32, 128, 8, 131072, None, "Phi-3-small-128k"), # Sparsity is not used in this test + (40, 128, 10, 131072, None, "Phi-3-medium-128K"), + ] + + # Reduce max sequence length when GPU memory is not enough. + threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 + + for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: + for batch_size in [1, 4]: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + model_name=model_name, + ) + plot_token_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + model_name=model_name, + ) if __name__ == "__main__":