mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Allow relocatable device code linking in pytorch CUDA extensions (#78225)
Close https://github.com/pytorch/pytorch/issues/57543 Doc: check `Relocatable device code linking:` in https://docs-preview.pytorch.org/78225/cpp_extension.html#torch.utils.cpp_extension.CUDAExtension Pull Request resolved: https://github.com/pytorch/pytorch/pull/78225 Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
parent
9446f9678a
commit
ef0332e36d
8 changed files with 153 additions and 2 deletions
21
test/cpp_extensions/cuda_dlink_extension.cpp
Normal file
21
test/cpp_extensions/cuda_dlink_extension.cpp
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
#include <torch/extension.h>

// Declaration of the launcher defined in cuda_dlink_extension_kernel.cu.
// That translation unit is compiled with relocatable device code (-dc) and
// resolved against cuda_dlink_extension_add.cu at the device-link step.
void add_cuda(const float* a, const float* b, float* output, int size);

// Element-wise addition of two CUDA float tensors: returns a + b.
// Raises (via TORCH_CHECK) if either input is not a CUDA float tensor, if the
// sizes differ, or if an input is not contiguous.
at::Tensor add(at::Tensor a, at::Tensor b) {
  // Original messages read "a is a cuda tensor" — but a TORCH_CHECK message is
  // printed when the condition FAILS, so phrase them as requirements.
  TORCH_CHECK(a.device().is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.device().is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
  TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
  TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same size");
  // add_cuda indexes raw pointers densely; a non-contiguous tensor would be
  // read with the wrong layout and produce silently wrong results.
  TORCH_CHECK(a.is_contiguous(), "a must be contiguous");
  TORCH_CHECK(b.is_contiguous(), "b must be contiguous");

  at::Tensor output = at::empty_like(a);
  add_cuda(a.data_ptr<float>(), b.data_ptr<float>(), output.data_ptr<float>(), a.numel());

  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("add", &add, "a + b");
}
|
||||
6
test/cpp_extensions/cuda_dlink_extension_add.cu
Normal file
6
test/cpp_extensions/cuda_dlink_extension_add.cu
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#include <cuda.h>
#include <cuda_runtime.h>

// Device-side scalar addition: writes *a + *b into *output.
// Declared in cuda_dlink_extension_add.cuh and called from a separate
// translation unit, so this definition exercises relocatable-device-code
// (-dc / -dlink) cross-unit linking.
__device__ void add(const float* a, const float* b, float* output) {
  output[0] = a[0] + b[0];
}
|
||||
6
test/cpp_extensions/cuda_dlink_extension_add.cuh
Normal file
6
test/cpp_extensions/cuda_dlink_extension_add.cuh
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
__device__ void add(const float* a, const float* b, float* output);
|
||||
22
test/cpp_extensions/cuda_dlink_extension_kernel.cu
Normal file
22
test/cpp_extensions/cuda_dlink_extension_kernel.cu
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "cuda_dlink_extension_add.cuh"
|
||||
|
||||
// Element-wise vector addition: output[idx] = a[idx] + b[idx] for idx < size.
// One thread per element over a 1-D grid of 1-D blocks; the tail guard makes
// any grid that covers `size` elements valid. The per-element work is
// delegated to the separately compiled __device__ add() declared in
// cuda_dlink_extension_add.cuh.
__global__ void add_kernel(const float* a, const float* b, float* output, int size) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size) {
    add(a + idx, b + idx, output + idx);
  }
}
|
||||
|
||||
// output = a + b (element-wise, over `size` floats).
// Host-side launcher for add_kernel. All pointers must be device pointers to
// at least `size` floats. (Original comment said "a * b + c", which did not
// match the computation.)
void add_cuda(const float* a, const float* b, float* output, int size) {
  // A grid of 0 blocks is an invalid launch configuration, which
  // C10_CUDA_KERNEL_LAUNCH_CHECK would turn into an error — treat an empty
  // input as a no-op instead.
  if (size <= 0) {
    return;
  }
  const int threads = 1024;
  const int blocks = (size + threads - 1) / threads;  // ceil(size / threads)
  add_kernel<<<blocks, threads>>>(a, b, output, size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
||||
|
|
@ -66,6 +66,19 @@ if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
|
|||
)
|
||||
ext_modules.append(cusolver_extension)
|
||||
|
||||
# Register the relocatable-device-code (dlink) test extension. Device linking
# is only available through the ninja backend, hence the USE_NINJA gate; it is
# also skipped on Windows and on machines without a usable CUDA toolkit.
if USE_NINJA and (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
    extension = CUDAExtension(
        name='torch_test_cpp_extension.cuda_dlink',
        sources=[
            'cuda_dlink_extension.cpp',
            'cuda_dlink_extension_kernel.cu',
            'cuda_dlink_extension_add.cu',
        ],
        dlink=True,
        extra_compile_args={
            'cxx': CXX_FLAGS,
            # -dc compiles the .cu files as relocatable device code so the
            # cross-file __device__ call can be resolved at the dlink step.
            'nvcc': ['-O2', '-dc'],
        },
    )
    ext_modules.append(extension)
|
||||
|
||||
setup(
|
||||
name='torch_test_cpp_extension',
|
||||
packages=['torch_test_cpp_extension'],
|
||||
|
|
|
|||
|
|
@ -488,6 +488,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
|||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
from shutil import copyfile
|
||||
|
||||
os.environ['USE_NINJA'] = shell_env['USE_NINJA']
|
||||
test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
|
||||
copyfile(
|
||||
test_directory + "/test_cpp_extensions_aot.py",
|
||||
|
|
@ -509,6 +510,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
|||
os.environ["PYTHONPATH"] = python_path
|
||||
if os.path.exists(test_directory + "/" + test_module + ".py"):
|
||||
os.remove(test_directory + "/" + test_module + ".py")
|
||||
os.environ.pop('USE_NINJA')
|
||||
|
||||
|
||||
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
|
||||
|
|
|
|||
|
|
@ -121,6 +121,17 @@ class TestCppExtensionAOT(common.TestCase):
|
|||
has_value = cpp_extension.function_taking_optional(None)
|
||||
self.assertFalse(has_value)
|
||||
|
||||
@common.skipIfRocm
@unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
@unittest.skipIf(os.getenv('USE_NINJA', '0') == '0', "cuda extension with dlink requires ninja to build")
def test_cuda_dlink_libs(self):
    """The dlink-built extension must add two CUDA float tensors correctly."""
    from torch_test_cpp_extension import cuda_dlink

    lhs = torch.randn(8, dtype=torch.float, device='cuda')
    rhs = torch.randn(8, dtype=torch.float, device='cuda')
    expected = lhs + rhs
    actual = cuda_dlink.add(lhs, rhs)
    self.assertEqual(actual, expected)
|
||||
|
||||
class TestORTTensor(common.TestCase):
|
||||
def test_unregistered(self):
|
||||
|
|
|
|||
|
|
@ -455,6 +455,9 @@ class BuildExtension(build_ext, object):
|
|||
self._define_torch_extension_name(extension)
|
||||
self._add_gnu_cpp_abi_flag(extension)
|
||||
|
||||
if 'nvcc_dlink' in extension.extra_compile_args:
|
||||
assert self.use_ninja, f"With dlink=True, ninja is required to build cuda extension {extension.name}."
|
||||
|
||||
# Register .cu, .cuh and .hip as valid source extensions.
|
||||
self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
|
||||
# Save the original _compile method for later.
|
||||
|
|
@ -583,6 +586,10 @@ class BuildExtension(build_ext, object):
|
|||
cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
|
||||
cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]
|
||||
|
||||
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
|
||||
cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink'])
|
||||
else:
|
||||
cuda_dlink_post_cflags = None
|
||||
_write_ninja_file_and_compile_objects(
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
|
|
@ -590,6 +597,7 @@ class BuildExtension(build_ext, object):
|
|||
post_cflags=[shlex.quote(f) for f in post_cflags],
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
build_directory=output_dir,
|
||||
verbose=True,
|
||||
with_cuda=with_cuda)
|
||||
|
|
@ -734,6 +742,10 @@ class BuildExtension(build_ext, object):
|
|||
if with_cuda:
|
||||
cuda_cflags = _nt_quote_args(cuda_cflags)
|
||||
cuda_post_cflags = _nt_quote_args(cuda_post_cflags)
|
||||
if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs:
|
||||
cuda_dlink_post_cflags = win_cuda_flags(extra_postargs['nvcc_dlink'])
|
||||
else:
|
||||
cuda_dlink_post_cflags = None
|
||||
|
||||
_write_ninja_file_and_compile_objects(
|
||||
sources=sources,
|
||||
|
|
@ -742,6 +754,7 @@ class BuildExtension(build_ext, object):
|
|||
post_cflags=post_cflags,
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
build_directory=output_dir,
|
||||
verbose=True,
|
||||
with_cuda=with_cuda)
|
||||
|
|
@ -978,6 +991,31 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
|||
Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
|
||||
Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
|
||||
|
||||
Relocatable device code linking:
|
||||
|
||||
If you want to reference device symbols across compilation units (across object files),
|
||||
the object files need to be built with `relocatable device code` (-rdc=true or -dc).
|
||||
An exception to this rule is "dynamic parallelism" (nested kernel launches) which is not used a lot anymore.
|
||||
`Relocatable device code` is less optimized so it needs to be used only on object files that need it.
|
||||
Using `-dlto` (Device Link Time Optimization) at the device code compilation step and `dlink` step
|
||||
helps reduce the potential perf degradation of `-rdc`.
|
||||
Note that it needs to be used at both steps to be useful.
|
||||
|
||||
If you have `rdc` objects you need to have an extra `-dlink` (device linking) step before the CPU symbol linking step.
|
||||
There is also a case where `-dlink` is used without `-rdc`:
|
||||
when an extension is linked against a static lib containing rdc-compiled objects
|
||||
like the [NVSHMEM library](https://developer.nvidia.com/nvshmem).
|
||||
|
||||
Note: Ninja is required to build a CUDA Extension with RDC linking.
|
||||
|
||||
Example:
|
||||
>>> CUDAExtension(
|
||||
name='cuda_extension',
|
||||
sources=['extension.cpp', 'extension_kernel.cu'],
|
||||
dlink=True,
|
||||
dlink_libraries=["dlink_lib"],
|
||||
extra_compile_args={'cxx': ['-g'],
|
||||
'nvcc': ['-O2', '-rdc=true']})
|
||||
'''
|
||||
library_dirs = kwargs.get('library_dirs', [])
|
||||
library_dirs += library_paths(cuda=True)
|
||||
|
|
@ -1031,6 +1069,23 @@ def CUDAExtension(name, sources, *args, **kwargs):
|
|||
|
||||
kwargs['language'] = 'c++'
|
||||
|
||||
dlink_libraries = kwargs.get('dlink_libraries', [])
|
||||
dlink = kwargs.get('dlink', False) or dlink_libraries
|
||||
if dlink:
|
||||
extra_compile_args = kwargs.get('extra_compile_args', {})
|
||||
|
||||
extra_compile_args_dlink = extra_compile_args.get('nvcc_dlink', [])
|
||||
extra_compile_args_dlink += ['-dlink']
|
||||
extra_compile_args_dlink += [f'-L{x}' for x in library_dirs]
|
||||
extra_compile_args_dlink += [f'-l{x}' for x in dlink_libraries]
|
||||
|
||||
if (torch.version.cuda is not None) and packaging.version.parse(torch.version.cuda) >= packaging.version.parse('11.2'):
|
||||
extra_compile_args_dlink += ['-dlto'] # Device Link Time Optimization started from cuda 11.2
|
||||
|
||||
extra_compile_args['nvcc_dlink'] = extra_compile_args_dlink
|
||||
|
||||
kwargs['extra_compile_args'] = extra_compile_args
|
||||
|
||||
return setuptools.Extension(name, sources, *args, **kwargs)
|
||||
|
||||
|
||||
|
|
@ -1457,6 +1512,7 @@ def _write_ninja_file_and_compile_objects(
|
|||
post_cflags,
|
||||
cuda_cflags,
|
||||
cuda_post_cflags,
|
||||
cuda_dlink_post_cflags,
|
||||
build_directory: str,
|
||||
verbose: bool,
|
||||
with_cuda: Optional[bool]) -> None:
|
||||
|
|
@ -1477,6 +1533,7 @@ def _write_ninja_file_and_compile_objects(
|
|||
post_cflags=post_cflags,
|
||||
cuda_cflags=cuda_cflags,
|
||||
cuda_post_cflags=cuda_post_cflags,
|
||||
cuda_dlink_post_cflags=cuda_dlink_post_cflags,
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
ldflags=None,
|
||||
|
|
@ -1966,6 +2023,7 @@ def _write_ninja_file_to_build_library(path,
|
|||
post_cflags=None,
|
||||
cuda_cflags=cuda_flags,
|
||||
cuda_post_cflags=None,
|
||||
cuda_dlink_post_cflags=None,
|
||||
sources=sources,
|
||||
objects=objects,
|
||||
ldflags=ldflags,
|
||||
|
|
@ -1978,6 +2036,7 @@ def _write_ninja_file(path,
|
|||
post_cflags,
|
||||
cuda_cflags,
|
||||
cuda_post_cflags,
|
||||
cuda_dlink_post_cflags,
|
||||
sources,
|
||||
objects,
|
||||
ldflags,
|
||||
|
|
@ -2007,6 +2066,7 @@ def _write_ninja_file(path,
|
|||
post_cflags = sanitize_flags(post_cflags)
|
||||
cuda_cflags = sanitize_flags(cuda_cflags)
|
||||
cuda_post_cflags = sanitize_flags(cuda_post_cflags)
|
||||
cuda_dlink_post_cflags = sanitize_flags(cuda_dlink_post_cflags)
|
||||
ldflags = sanitize_flags(ldflags)
|
||||
|
||||
# Sanity checks...
|
||||
|
|
@ -2021,7 +2081,7 @@ def _write_ninja_file(path,
|
|||
# Version 1.3 is required for the `deps` directive.
|
||||
config = ['ninja_required_version = 1.3']
|
||||
config.append(f'cxx = {compiler}')
|
||||
if with_cuda:
|
||||
if with_cuda or cuda_dlink_post_cflags:
|
||||
if IS_HIP_EXTENSION:
|
||||
nvcc = _join_rocm_home('bin', 'hipcc')
|
||||
else:
|
||||
|
|
@ -2035,6 +2095,7 @@ def _write_ninja_file(path,
|
|||
if with_cuda:
|
||||
flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
|
||||
flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
|
||||
flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}')
|
||||
flags.append(f'ldflags = {" ".join(ldflags)}')
|
||||
|
||||
# Turn into absolute paths so we can emit them into the ninja build
|
||||
|
|
@ -2084,6 +2145,15 @@ def _write_ninja_file(path,
|
|||
object_file = object_file.replace(" ", "$ ")
|
||||
build.append(f'build {object_file}: {rule} {source_file}')
|
||||
|
||||
if cuda_dlink_post_cflags:
|
||||
devlink_out = os.path.join(os.path.dirname(objects[0]), 'dlink.o')
|
||||
devlink_rule = ['rule cuda_devlink']
|
||||
devlink_rule.append(' command = $nvcc $in -o $out $cuda_dlink_post_cflags')
|
||||
devlink = [f'build {devlink_out}: cuda_devlink {" ".join(objects)}']
|
||||
objects += [devlink_out]
|
||||
else:
|
||||
devlink_rule, devlink = [], []
|
||||
|
||||
if library_target is not None:
|
||||
link_rule = ['rule link']
|
||||
if IS_WINDOWS:
|
||||
|
|
@ -2107,7 +2177,7 @@ def _write_ninja_file(path,
|
|||
blocks = [config, flags, compile_rule]
|
||||
if with_cuda:
|
||||
blocks.append(cuda_compile_rule)
|
||||
blocks += [link_rule, build, link, default]
|
||||
blocks += [devlink_rule, link_rule, build, devlink, link, default]
|
||||
with open(path, 'w') as build_file:
|
||||
for block in blocks:
|
||||
lines = '\n'.join(block)
|
||||
|
|
|
|||
Loading…
Reference in a new issue