diff --git a/CMakeLists.txt b/CMakeLists.txt index f3531e9ad9f..98593c2de97 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -883,6 +883,16 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA OR USE_ROCM" OFF) +# +# Cannot be put into Dependencies.cmake due circular dependency: +# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake +# +if(USE_ROCM) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) + include(cmake/External/aotriton.cmake) + endif() +endif() + if(DEBUG_CUDA) string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo") string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo") diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index f093878ee3c..66178a8d879 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -25,7 +25,10 @@ #include #if USE_ROCM +#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION) #include +#define USE_AOTRITON 1 +#endif #endif /** @@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm80 = SMVersion<8, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug } return false; } +#else + return false; +#endif #else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { @@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) using sm50 = SMVersion<5, 0>; using sm90 = SMVersion<9, 0>; #if USE_ROCM +#if USE_AOTRITON auto stream = at::cuda::getCurrentCUDAStream().stream(); if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) { auto dprops = at::cuda::getCurrentDeviceProperties(); @@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) } return false; } +#else + return false; +#endif #else auto dprops = at::cuda::getCurrentDeviceProperties(); if (!check_sm_version(dprops)) { diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ef33a316534..86cb3b28b6e 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1103,10 +1103,6 @@ if(USE_ROCM) message(STATUS "Disabling Kernel Assert for ROCm") endif() - include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake) - if(USE_CUDA) - caffe2_update_option(USE_MEM_EFF_ATTENTION OFF) - endif() else() caffe2_update_option(USE_ROCM OFF) endif()