From ef0332e36d3b5faa7dd7621b2ed766e0e78e40f9 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 2 Jun 2022 21:35:56 +0000 Subject: [PATCH] Allow relocatable device code linking in pytorch CUDA extensions (#78225) Close https://github.com/pytorch/pytorch/issues/57543 Doc: check `Relocatable device code linking:` in https://docs-preview.pytorch.org/78225/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension Pull Request resolved: https://github.com/pytorch/pytorch/pull/78225 Approved by: https://github.com/ezyang, https://github.com/malfet --- test/cpp_extensions/cuda_dlink_extension.cpp | 21 ++++++ .../cuda_dlink_extension_add.cu | 6 ++ .../cuda_dlink_extension_add.cuh | 6 ++ .../cuda_dlink_extension_kernel.cu | 22 ++++++ test/cpp_extensions/setup.py | 13 ++++ test/run_test.py | 2 + test/test_cpp_extensions_aot.py | 11 +++ torch/utils/cpp_extension.py | 74 ++++++++++++++++++- 8 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 test/cpp_extensions/cuda_dlink_extension.cpp create mode 100644 test/cpp_extensions/cuda_dlink_extension_add.cu create mode 100644 test/cpp_extensions/cuda_dlink_extension_add.cuh create mode 100644 test/cpp_extensions/cuda_dlink_extension_kernel.cu diff --git a/test/cpp_extensions/cuda_dlink_extension.cpp b/test/cpp_extensions/cuda_dlink_extension.cpp new file mode 100644 index 00000000000..46aedeff9a9 --- /dev/null +++ b/test/cpp_extensions/cuda_dlink_extension.cpp @@ -0,0 +1,21 @@ +#include + +// Declare the function from cuda_dlink_extension.cu. 
+void add_cuda(const float* a, const float* b, float* output, int size); + +at::Tensor add(at::Tensor a, at::Tensor b) { + TORCH_CHECK(a.device().is_cuda(), "a is a cuda tensor"); + TORCH_CHECK(b.device().is_cuda(), "b is a cuda tensor"); + TORCH_CHECK(a.dtype() == at::kFloat, "a is a float tensor"); + TORCH_CHECK(b.dtype() == at::kFloat, "b is a float tensor"); + TORCH_CHECK(a.sizes() == b.sizes(), "a and b should have same size"); + + at::Tensor output = at::empty_like(a); + add_cuda(a.data_ptr(), b.data_ptr(), output.data_ptr(), a.numel()); + + return output; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("add", &add, "a + b"); +} diff --git a/test/cpp_extensions/cuda_dlink_extension_add.cu b/test/cpp_extensions/cuda_dlink_extension_add.cu new file mode 100644 index 00000000000..5e19d8e069d --- /dev/null +++ b/test/cpp_extensions/cuda_dlink_extension_add.cu @@ -0,0 +1,6 @@ +#include +#include + +__device__ void add(const float* a, const float* b, float* output) { + *output = *a + *b; +} diff --git a/test/cpp_extensions/cuda_dlink_extension_add.cuh b/test/cpp_extensions/cuda_dlink_extension_add.cuh new file mode 100644 index 00000000000..9427caaac28 --- /dev/null +++ b/test/cpp_extensions/cuda_dlink_extension_add.cuh @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +__device__ void add(const float* a, const float* b, float* output); diff --git a/test/cpp_extensions/cuda_dlink_extension_kernel.cu b/test/cpp_extensions/cuda_dlink_extension_kernel.cu new file mode 100644 index 00000000000..66e6077f59f --- /dev/null +++ b/test/cpp_extensions/cuda_dlink_extension_kernel.cu @@ -0,0 +1,22 @@ +#include +#include +#include + +#include + +#include "cuda_dlink_extension_add.cuh" + +__global__ void add_kernel(const float* a, const float* b, float* output, int size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < size) { + add(a + i, b + i, output + i); + } +} + +// output = a + b +void add_cuda(const float* a, const float* b, float* output, 
int size) { + const int threads = 1024; + const int blocks = (size + threads - 1) / threads; + add_kernel<<>>(a, b, output, size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index df533941730..db5c49dc2a1 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -66,6 +66,19 @@ if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None: ) ext_modules.append(cusolver_extension) +if USE_NINJA and (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension( + name='torch_test_cpp_extension.cuda_dlink', + sources=[ + 'cuda_dlink_extension.cpp', + 'cuda_dlink_extension_kernel.cu', + 'cuda_dlink_extension_add.cu', + ], + dlink=True, + extra_compile_args={'cxx': CXX_FLAGS, + 'nvcc': ['-O2', '-dc']}) + ext_modules.append(extension) + setup( name='torch_test_cpp_extension', packages=['torch_test_cpp_extension'], diff --git a/test/run_test.py b/test/run_test.py index f78c1650699..5c4d665662b 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -488,6 +488,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): python_path = os.environ.get("PYTHONPATH", "") from shutil import copyfile + os.environ['USE_NINJA'] = shell_env['USE_NINJA'] test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja") copyfile( test_directory + "/test_cpp_extensions_aot.py", @@ -509,6 +510,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja): os.environ["PYTHONPATH"] = python_path if os.path.exists(test_directory + "/" + test_module + ".py"): os.remove(test_directory + "/" + test_module + ".py") + os.environ.pop('USE_NINJA') def test_cpp_extensions_aot_ninja(test_module, test_directory, options): diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index fc3c7be6106..2f505553859 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -121,6 
+121,17 @@ class TestCppExtensionAOT(common.TestCase): has_value = cpp_extension.function_taking_optional(None) self.assertFalse(has_value) + @common.skipIfRocm + @unittest.skipIf(common.IS_WINDOWS, "Windows not supported") + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + @unittest.skipIf(os.getenv('USE_NINJA', '0') == '0', "cuda extension with dlink requires ninja to build") + def test_cuda_dlink_libs(self): + from torch_test_cpp_extension import cuda_dlink + a = torch.randn(8, dtype=torch.float, device='cuda') + b = torch.randn(8, dtype=torch.float, device='cuda') + ref = a + b + test = cuda_dlink.add(a, b) + self.assertEqual(test, ref) class TestORTTensor(common.TestCase): def test_unregistered(self): diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aed4fd5004b..edcdcd0fce0 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -455,6 +455,9 @@ class BuildExtension(build_ext, object): self._define_torch_extension_name(extension) self._add_gnu_cpp_abi_flag(extension) + if 'nvcc_dlink' in extension.extra_compile_args: + assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}." + # Register .cu, .cuh and .hip as valid source extensions. self.compiler.src_extensions += ['.cu', '.cuh', '.hip'] # Save the original _compile method for later. 
@@ -583,6 +586,10 @@ class BuildExtension(build_ext, object): cuda_cflags = [shlex.quote(f) for f in cuda_cflags] cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags] + if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: + cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink']) + else: + cuda_dlink_post_cflags = None _write_ninja_file_and_compile_objects( sources=sources, objects=objects, @@ -590,6 +597,7 @@ class BuildExtension(build_ext, object): post_cflags=[shlex.quote(f) for f in post_cflags], cuda_cflags=cuda_cflags, cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, build_directory=output_dir, verbose=True, with_cuda=with_cuda) @@ -734,6 +742,10 @@ class BuildExtension(build_ext, object): if with_cuda: cuda_cflags = _nt_quote_args(cuda_cflags) cuda_post_cflags = _nt_quote_args(cuda_post_cflags) + if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: + cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink']) + else: + cuda_dlink_post_cflags = None _write_ninja_file_and_compile_objects( sources=sources, @@ -742,6 +754,7 @@ class BuildExtension(build_ext, object): post_cflags=post_cflags, cuda_cflags=cuda_cflags, cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, build_directory=output_dir, verbose=True, with_cuda=with_cuda) @@ -978,6 +991,31 @@ def CUDAExtension(name, sources, *args, **kwargs): Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460 Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48 + Relocatable device code linking: + + If you want to reference device symbols across compilation units (across object files), + the object files need to be built with `relocatable device code` (-rdc=true or -dc). + An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore. 
+ `Relocatable device code` is less optimized so it needs to be used only on object files that need it. + Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step + helps reduce the potential perf degradation of `-rdc`. + Note that it needs to be used at both steps to be useful. + + If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step. + There is also a case where `-dlink` is used without `-rdc`: + when an extension is linked against a static lib containing rdc-compiled objects + like the [NVSHMEM library](https://developer.nvidia.com/nvshmem). + + Note: Ninja is required to build a CUDA Extension with RDC linking. + + Example: + >>> CUDAExtension( + name='cuda_extension', + sources=['extension.cpp', 'extension_kernel.cu'], + dlink=True, + dlink_libraries=["dlink_lib"], + extra_compile_args={'cxx': ['-g'], + 'nvcc': ['-O2', '-rdc=true']}) ''' library_dirs = kwargs.get('library_dirs', []) library_dirs += library_paths(cuda=True) @@ -1031,6 +1069,23 @@ def CUDAExtension(name, sources, *args, **kwargs): kwargs['language'] = 'c++' + dlink_libraries = kwargs.get('dlink_libraries', []) + dlink = kwargs.get('dlink', False) or dlink_libraries + if dlink: + extra_compile_args = kwargs.get('extra_compile_args', {}) + + extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', []) + extra_compile_args_dlink += ['-dlink'] + extra_compile_args_dlink += [f'-L{x}' for x in library_dirs] + extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries] + + if (torch.version.cuda is not None) and packaging.version.parse(torch.version.cuda) >= packaging.version.parse('11.2'): + extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2 + + extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink + + kwargs['extra_compile_args'] = extra_compile_args + + return setuptools.Extension(name, sources, *args, **kwargs) @@ -1457,6 +1512,7 @@ 
def _write_ninja_file_and_compile_objects( post_cflags, cuda_cflags, cuda_post_cflags, + cuda_dlink_post_cflags, build_directory: str, verbose: bool, with_cuda: Optional[bool]) -> None: @@ -1477,6 +1533,7 @@ def _write_ninja_file_and_compile_objects( post_cflags=post_cflags, cuda_cflags=cuda_cflags, cuda_post_cflags=cuda_post_cflags, + cuda_dlink_post_cflags=cuda_dlink_post_cflags, sources=sources, objects=objects, ldflags=None, @@ -1966,6 +2023,7 @@ def _write_ninja_file_to_build_library(path, post_cflags=None, cuda_cflags=cuda_flags, cuda_post_cflags=None, + cuda_dlink_post_cflags=None, sources=sources, objects=objects, ldflags=ldflags, @@ -1978,6 +2036,7 @@ def _write_ninja_file(path, post_cflags, cuda_cflags, cuda_post_cflags, + cuda_dlink_post_cflags, sources, objects, ldflags, @@ -2007,6 +2066,7 @@ def _write_ninja_file(path, post_cflags = sanitize_flags(post_cflags) cuda_cflags = sanitize_flags(cuda_cflags) cuda_post_cflags = sanitize_flags(cuda_post_cflags) + cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags) ldflags = sanitize_flags(ldflags) # Sanity checks... @@ -2021,7 +2081,7 @@ def _write_ninja_file(path, # Version 1.3 is required for the `deps` directive. 
config = ['ninja_required_version = 1.3'] config.append(f'cxx = {compiler}') - if with_cuda: + if with_cuda or cuda_dlink_post_cflags: if IS_HIP_EXTENSION: nvcc = _join_rocm_home('bin', 'hipcc') else: @@ -2035,6 +2095,7 @@ def _write_ninja_file(path, if with_cuda: flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}') flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}') + flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') flags.append(f'ldflags = {" ".join(ldflags)}') # Turn into absolute paths so we can emit them into the ninja build @@ -2084,6 +2145,15 @@ def _write_ninja_file(path, object_file = object_file.replace(" ", "$ ") build.append(f'build {object_file}: {rule} {source_file}') + if cuda_dlink_post_cflags: + devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o') + devlink_rule = ['rule cuda_devlink'] + devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags') + devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}'] + objects += [devlink_out] + else: + devlink_rule, devlink = [], [] + if library_target is not None: link_rule = ['rule link'] if IS_WINDOWS: @@ -2107,7 +2177,7 @@ def _write_ninja_file(path, blocks = [config, flags, compile_rule] if with_cuda: blocks.append(cuda_compile_rule) - blocks += [link_rule, build, link, default] + blocks += [devlink_rule, link_rule, build, devlink, link, default] with open(path, 'w') as build_file: for block in blocks: lines = '\n'.join(block)