2023-01-20 20:33:01 +00:00
|
|
|
if (onnxruntime_USE_FLASH_ATTENTION)
|
|
|
|
|
include(FetchContent)
|
2023-01-25 17:43:48 +00:00
|
|
|
FetchContent_Declare(
|
|
|
|
|
cutlass
|
|
|
|
|
URL ${DEP_URL_cutlass}
|
|
|
|
|
URL_HASH SHA1=${DEP_SHA1_cutlass}
|
Extend memory efficient attention coverage in Attention/MHA cuda op (#15064)
### Description
<!-- Describe your changes. -->
1. Upgrade cutlass to 3.0, which contains attn_bias support.
2. Extend Attention/MHA to use memory efficient attention when a
rel_pos_bias of shape [1, num_head, s, s*] and a 1D mask of shape
[2 * batch_size + 1] are present.
New mask format introduced:
MASK_1D_KEY_SEQ_LEN_START,
[3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1],
query_start[0], ..., query_start[batch_size - 1], query_end[batch_size -
1], key_start[0], ..., key_start[batch_size - 1], key_end[batch_size -
1]]
e.g.,
the 2D mask [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]] converts to
the 1D mask [3, 5, 0, 6, 12, 0, 6, 12]
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
It potentially benefits tnlrv6 and t5 (encoder).
---------
Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2023-03-23 18:05:17 +00:00
|
|
|
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
|
2023-01-20 20:33:01 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Query the population state of the `cutlass` content; this sets
# cutlass_POPULATED (among other cutlass_* variables) in the current scope.
FetchContent_GetProperties(cutlass)
|
|
|
|
|
# Guard so the content is populated at most once per configure run.
if(NOT cutlass_POPULATED)
|
|
|
|
|
# Download/unpack the sources only. Unlike FetchContent_MakeAvailable(),
# FetchContent_Populate() does not add the fetched tree to the build.
FetchContent_Populate(cutlass)
|
|
|
|
|
endif()
|
|
|
|
|
endif()
|