Allow relocatable device code linking in pytorch CUDA extensions (#78225)

Close https://github.com/pytorch/pytorch/issues/57543

Doc: check `Relocatable device code linking:` in https://docs-preview.pytorch.org/78225/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78225
Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
Xiao Wang 2022-06-02 21:35:56 +00:00 committed by PyTorch MergeBot
parent 9446f9678a
commit ef0332e36d
8 changed files with 153 additions and 2 deletions

View file

@ -0,0 +1,21 @@
#include <torch/extension.h>
// Declare the function from cuda_dlink_extension.cu.
void add_cuda(const float* a, const float* b, float* output, int size);
// CPU-side entry point: validates inputs, then launches the CUDA add
// (add_cuda is defined in cuda_dlink_extension_kernel.cu and resolved at
// device-link time). Returns a new tensor holding a + b element-wise.
at::Tensor add(at::Tensor a, at::Tensor b) {
  // NOTE: TORCH_CHECK prints its message when the condition is FALSE, so the
  // message must state the requirement, not the (failed) assumption.
  TORCH_CHECK(a.device().is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.device().is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same size");
  // The kernel indexes the raw pointers linearly, which is only valid for
  // contiguous storage.
  TORCH_CHECK(a.is_contiguous() && b.is_contiguous(),
              "a and b must be contiguous");
  at::Tensor output = at::empty_like(a);
  add_cuda(a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), a.numel());
  return output;
}
// Python bindings: exposes `add` on the extension module
// (module name is injected via TORCH_EXTENSION_NAME at build time).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("add", &add, "a + b");
}

View file

@ -0,0 +1,6 @@
#include <cuda.h>
#include <cuda_runtime.h>
// Device-side helper: stores *a + *b into *output.
// Kept in its own translation unit (separate from the kernel that calls it)
// so the build exercises relocatable-device-code linking.
__device__ void add(const float* a, const float* b, float* output) {
  output[0] = a[0] + b[0];
}

View file

@ -0,0 +1,6 @@
// Declaration of the device-side `add` helper defined in
// cuda_dlink_extension_add.cu; callers in other .cu files are resolved
// at the device-link (-dlink) step.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
// Adds the single values *a and *b and writes the result to *output.
__device__ void add(const float* a, const float* b, float* output);

View file

@ -0,0 +1,22 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAException.h>
#include <ATen/ATen.h>
#include "cuda_dlink_extension_add.cuh"
// One thread per element: thread `idx` adds a[idx] and b[idx] into
// output[idx]. The call to add() is defined in a separate compilation unit
// and is resolved during device linking.
__global__ void add_kernel(const float* a, const float* b, float* output, int size) {
  const int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= size) {
    return;
  }
  add(a + idx, b + idx, output + idx);
}
// output[i] = a[i] + b[i] for i in [0, size).
// (The previous comment said "a * b + c", which did not match the kernel.)
void add_cuda(const float* a, const float* b, float* output, int size) {
const int threads = 1024;
const int blocks = (size + threads - 1) / threads;
add_kernel<<<blocks, threads>>>(a, b, output, size);
// Report launch failures immediately instead of deferring to a later sync.
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

View file

@ -66,6 +66,19 @@ if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
)
ext_modules.append(cusolver_extension)
# Device-link (dlink) extensions require ninja: BuildExtension asserts
# use_ninja whenever 'nvcc_dlink' flags are present, hence the USE_NINJA gate.
if USE_NINJA and (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension(
name='torch_test_cpp_extension.cuda_dlink',
sources=[
'cuda_dlink_extension.cpp',
'cuda_dlink_extension_kernel.cu',
'cuda_dlink_extension_add.cu',
],
# dlink=True adds an extra -dlink device-linking step for the objects.
dlink=True,
# -dc compiles the .cu files as relocatable device code, which the
# device-link step then resolves across object files.
extra_compile_args={'cxx': CXX_FLAGS,
'nvcc': ['-O2', '-dc']})
ext_modules.append(extension)
setup(
name='torch_test_cpp_extension',
packages=['torch_test_cpp_extension'],

View file

@ -488,6 +488,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
python_path = os.environ.get("PYTHONPATH", "")
from shutil import copyfile
os.environ['USE_NINJA'] = shell_env['USE_NINJA']
test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
copyfile(
test_directory + "/test_cpp_extensions_aot.py",
@ -509,6 +510,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
os.environ["PYTHONPATH"] = python_path
if os.path.exists(test_directory + "/" + test_module + ".py"):
os.remove(test_directory + "/" + test_module + ".py")
os.environ.pop('USE_NINJA')
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):

View file

@ -121,6 +121,17 @@ class TestCppExtensionAOT(common.TestCase):
has_value = cpp_extension.function_taking_optional(None)
self.assertFalse(has_value)
@common.skipIfRocm
@unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(os.getenv('USE_NINJA', '0') == '0', "cuda extension with dlink requires ninja to build")
def test_cuda_dlink_libs(self):
    """Check that the device-linked extension adds two CUDA tensors correctly."""
    # Imported lazily so collection succeeds when the extension wasn't built.
    from torch_test_cpp_extension import cuda_dlink
    lhs = torch.randn(8, dtype=torch.float, device='cuda')
    rhs = torch.randn(8, dtype=torch.float, device='cuda')
    expected = lhs + rhs
    actual = cuda_dlink.add(lhs, rhs)
    self.assertEqual(actual, expected)
class TestORTTensor(common.TestCase):
def test_unregistered(self):

View file

@ -455,6 +455,9 @@ class BuildExtension(build_ext, object):
self._define_torch_extension_name(extension)
self._add_gnu_cpp_abi_flag(extension)
if 'nvcc_dlink' in extension.extra_compile_args:
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
# Register .cu, .cuh and .hip as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
# Save the original _compile method for later.
@ -583,6 +586,10 @@ class BuildExtension(build_ext, object):
cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
else:
cuda_dlink_post_cflags = None
_write_ninja_file_and_compile_objects(
sources=sources,
objects=objects,
@ -590,6 +597,7 @@ class BuildExtension(build_ext, object):
post_cflags=[shlex.quote(f) for f in post_cflags],
cuda_cflags=cuda_cflags,
cuda_post_cflags=cuda_post_cflags,
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
build_directory=output_dir,
verbose=True,
with_cuda=with_cuda)
@ -734,6 +742,10 @@ class BuildExtension(build_ext, object):
if with_cuda:
cuda_cflags = _nt_quote_args(cuda_cflags)
cuda_post_cflags = _nt_quote_args(cuda_post_cflags)
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink'])
else:
cuda_dlink_post_cflags = None
_write_ninja_file_and_compile_objects(
sources=sources,
@ -742,6 +754,7 @@ class BuildExtension(build_ext, object):
post_cflags=post_cflags,
cuda_cflags=cuda_cflags,
cuda_post_cflags=cuda_post_cflags,
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
build_directory=output_dir,
verbose=True,
with_cuda=with_cuda)
@ -978,6 +991,31 @@ def CUDAExtension(name, sources, *args, **kwargs):
Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
Relocatable device code linking:
If you want to reference device symbols across compilation units (across object files),
the object files need to be built with `relocatable device code` (-rdc=true or -dc).
An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
`Relocatable device code` is less optimized so it needs to be used only on object files that need it.
Using `-dlto` (Device Link Time Optimization) at both the device code compilation step and the `dlink` step
helps reduce the potential perf degradation of `-rdc`.
Note that it needs to be used at both steps to be useful.
If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
There is also a case where `-dlink` is used without `-rdc`:
when an extension is linked against a static lib containing rdc-compiled objects
like the [NVSHMEM library](https://developer.nvidia.com/nvshmem).
Note: Ninja is required to build a CUDA Extension with RDC linking.
Example:
>>> CUDAExtension(
name='cuda_extension',
sources=['extension.cpp', 'extension_kernel.cu'],
dlink=True,
dlink_libraries=["dlink_lib"],
extra_compile_args={'cxx': ['-g'],
'nvcc': ['-O2', '-rdc=true']})
'''
library_dirs = kwargs.get('library_dirs', [])
library_dirs += library_paths(cuda=True)
@ -1031,6 +1069,23 @@ def CUDAExtension(name, sources, *args, **kwargs):
kwargs['language'] = 'c++'
dlink_libraries = kwargs.get('dlink_libraries', [])
dlink = kwargs.get('dlink', False) or dlink_libraries
if dlink:
extra_compile_args = kwargs.get('extra_compile_args', {})
extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', [])
extra_compile_args_dlink += ['-dlink']
extra_compile_args_dlink += [f'-L{x}' for x in library_dirs]
extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries]
if (torch.version.cuda is not None) and packaging.version.parse(torch.version.cuda) >= packaging.version.parse('11.2'):
extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2
extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink
kwargs['extra_compile_args'] = extra_compile_args
return setuptools.Extension(name, sources, *args, **kwargs)
@ -1457,6 +1512,7 @@ def _write_ninja_file_and_compile_objects(
post_cflags,
cuda_cflags,
cuda_post_cflags,
cuda_dlink_post_cflags,
build_directory: str,
verbose: bool,
with_cuda: Optional[bool]) -> None:
@ -1477,6 +1533,7 @@ def _write_ninja_file_and_compile_objects(
post_cflags=post_cflags,
cuda_cflags=cuda_cflags,
cuda_post_cflags=cuda_post_cflags,
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
sources=sources,
objects=objects,
ldflags=None,
@ -1966,6 +2023,7 @@ def _write_ninja_file_to_build_library(path,
post_cflags=None,
cuda_cflags=cuda_flags,
cuda_post_cflags=None,
cuda_dlink_post_cflags=None,
sources=sources,
objects=objects,
ldflags=ldflags,
@ -1978,6 +2036,7 @@ def _write_ninja_file(path,
post_cflags,
cuda_cflags,
cuda_post_cflags,
cuda_dlink_post_cflags,
sources,
objects,
ldflags,
@ -2007,6 +2066,7 @@ def _write_ninja_file(path,
post_cflags = sanitize_flags(post_cflags)
cuda_cflags = sanitize_flags(cuda_cflags)
cuda_post_cflags = sanitize_flags(cuda_post_cflags)
cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
ldflags = sanitize_flags(ldflags)
# Sanity checks...
@ -2021,7 +2081,7 @@ def _write_ninja_file(path,
# Version 1.3 is required for the `deps` directive.
config = ['ninja_required_version = 1.3']
config.append(f'cxx = {compiler}')
if with_cuda:
if with_cuda or cuda_dlink_post_cflags:
if IS_HIP_EXTENSION:
nvcc = _join_rocm_home('bin', 'hipcc')
else:
@ -2035,6 +2095,7 @@ def _write_ninja_file(path,
if with_cuda:
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
flags.append(f'ldflags = {" ".join(ldflags)}')
# Turn into absolute paths so we can emit them into the ninja build
@ -2084,6 +2145,15 @@ def _write_ninja_file(path,
object_file = object_file.replace(" ", "$ ")
build.append(f'build {object_file}: {rule} {source_file}')
if cuda_dlink_post_cflags:
devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
devlink_rule = ['rule cuda_devlink']
devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
objects += [devlink_out]
else:
devlink_rule, devlink = [], []
if library_target is not None:
link_rule = ['rule link']
if IS_WINDOWS:
@ -2107,7 +2177,7 @@ def _write_ninja_file(path,
blocks = [config, flags, compile_rule]
if with_cuda:
blocks.append(cuda_compile_rule)
blocks += [link_rule, build, link, default]
blocks += [devlink_rule, link_rule, build, devlink, link, default]
with open(path, 'w') as build_file:
for block in blocks:
lines = '\n'.join(block)