diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 67f58004ffd..0e0243b664a 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -280,6 +280,7 @@ COMMON_HIPCC_FLAGS = [ '-DCUDA_HAS_FP16=1', '-D__HIP_NO_HALF_OPERATORS__=1', '-D__HIP_NO_HALF_CONVERSIONS__=1', + '-DHIP_ENABLE_WARP_SYNC_BUILTINS=1' ] JIT_EXTENSION_VERSIONER = ExtensionVersioner()