diff --git a/.gitignore b/.gitignore
index 0a16dd79f22..5b2a18a2f46 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,6 +52,8 @@ test/cpp_extensions/install/
 test/test-reports/
 third_party/build/
 tools/shared/_utils_internal.py
+tools/fast_nvcc/wrap_nvcc.sh
+tools/fast_nvcc/tmp/
 torch.egg-info/
 torch/_C/__init__.pyi
 torch/_C/_nn.pyi
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b84a30469cb..e0f794274df 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -161,6 +161,7 @@ option(COLORIZE_OUTPUT "Colorize output during compilation" ON)
 option(USE_ASAN "Use Address Sanitizer" OFF)
 option(USE_TSAN "Use Thread Sanitizer" OFF)
 option(USE_CUDA "Use CUDA" ON)
+option(USE_FAST_NVCC "Use parallel NVCC build" OFF)
 option(USE_ROCM "Use ROCm" ON)
 option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 cmake_dependent_option(
diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake
index ec8a732e307..fd457bea956 100644
--- a/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake
+++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA.cmake
@@ -767,6 +767,28 @@ else()
   # Search default search paths, after we search our own set of paths.
   cuda_find_host_program(CUDA_NVCC_EXECUTABLE nvcc)
 endif()
+
+# FAST_NVCC: route nvcc invocations through tools/fast_nvcc/fast_nvcc.py by
+# pointing CUDA_NVCC_EXECUTABLE at a generated wrapper script. The real nvcc
+# path is kept in CUDA_NVCC_EXECUTABLE_ORIGIN, which also guards against
+# wrapping the wrapper if this block is evaluated again.
+if(USE_FAST_NVCC AND CUDA_NVCC_EXECUTABLE AND NOT CUDA_NVCC_EXECUTABLE_ORIGIN)
+  set(CUDA_NVCC_EXECUTABLE_ORIGIN "${CUDA_NVCC_EXECUTABLE}")
+  set(FAST_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/fast_nvcc.py")
+  # @ONLY restricts substitution to @VAR@ references, so shell syntax such as
+  # ${VAR} in the template is never expanded by CMake.
+  configure_file(
+    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh.in"
+    "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh"
+    @ONLY
+  )
+  # Copy the generated script into place with execute permissions set.
+  file(COPY "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/tmp/wrap_nvcc.sh"
+    DESTINATION "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/"
+    FILE_PERMISSIONS OWNER_READ OWNER_WRITE OWNER_EXECUTE GROUP_READ GROUP_EXECUTE WORLD_READ WORLD_EXECUTE
+  )
+  set(CUDA_NVCC_EXECUTABLE "${PROJECT_SOURCE_DIR}/tools/fast_nvcc/wrap_nvcc.sh")
+endif()
 mark_as_advanced(CUDA_NVCC_EXECUTABLE)
 
 if(CUDA_NVCC_EXECUTABLE AND NOT CUDA_VERSION)
@@ -789,7 +811,6 @@ else()
   string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\2" CUDA_VERSION_MINOR "${CUDA_VERSION}")
 endif()
 
-
 # Always set this convenience variable
 set(CUDA_VERSION_STRING "${CUDA_VERSION}")
 
diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py
index 2a8d1d73145..fc14191ffd1 100755
--- a/tools/fast_nvcc/fast_nvcc.py
+++ b/tools/fast_nvcc/fast_nvcc.py
@@ -157,6 +157,19 @@ def warn_if_tmpdir_set(env):
         fast_nvcc_warn(url_vars)
 
 
+def contains_non_executable(commands):
+    """
+    Return true iff some command starts with "--" and so cannot be executed.
+    """
+    # NVCC's dry-run output can contain such special entries, for example:
+    # ```
+    # #$ "/lib64/ccache"/c++ -std=c++11 -E -x c++ -D__CUDACC__ -D__NVCC__ -fPIC -fvisibility=hidden -O3 \
+    #     -I ... -m64 "reduce_scatter.cu" > "/tmp/tmpxft_0037fae3_00000000-5_reduce_scatter.cpp4.ii"
+    # #$ -- Filter Dependencies -- > ... pytorch/build/nccl/obj/collectives/device/reduce_scatter.dep.tmp
+    # ```
+    return any(command.startswith("--") for command in commands)
+
+
 def module_id_contents(command):
     """
     Guess the contents of the .module_id file contained within command.
@@ -275,6 +288,8 @@ def is_weakly_connected(graph):
     """
     Return true iff graph is weakly connected.
     """
+    if not graph:
+        return True
     neighbors = [set() for _ in graph]
     for node, predecessors in enumerate(graph):
         for pred in predecessors:
@@ -408,6 +423,13 @@ def exit_code(results):
     return 0
 
 
+def wrap_nvcc(args, config=default_config):
+    """
+    Run config.nvcc directly with args and return its exit code.
+    """
+    return subprocess.call([config.nvcc] + args)
+
+
 def fast_nvcc(args, *, config=default_config):
     """
     Emulate the result of calling the given nvcc binary with args.
@@ -422,6 +444,12 @@ def fast_nvcc(args, *, config=default_config):
     commands = dryrun_data['commands']
     if not config.faithful:
         commands = make_rm_force(unique_module_id_files(commands))
+
+    # Some dry-run entries (e.g. "-- Filter Dependencies --") are not real
+    # commands, so fall back to running nvcc itself on the original args.
+    if contains_non_executable(commands):
+        return wrap_nvcc(args, config)
+
     command_parts = list(map(shlex.split, commands))
     if config.verbose:
         print_verbose_output(
diff --git a/tools/fast_nvcc/wrap_nvcc.sh.in b/tools/fast_nvcc/wrap_nvcc.sh.in
new file mode 100644
index 00000000000..d15ed016068
--- /dev/null
+++ b/tools/fast_nvcc/wrap_nvcc.sh.in
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+# CMake is not happy about a dangling "--" when defining CUDA_NVCC_EXECUTABLE,
+# so the fast_nvcc command line is wrapped in this shell script instead.
+# The substituted paths are quoted so that paths containing spaces still work.
+exec "@FAST_NVCC_EXECUTABLE@" --nvcc "@CUDA_NVCC_EXECUTABLE_ORIGIN@" -- "$@"