Add support for dotted names in CPP Extensions (#6986)

* Add support for dotted names in CPP Extensions

* Modify tests for cpp extensions

Test that dotted names work

* Py2 fixes

* Make run_test cpp_extensions Win-compatible
This commit is contained in:
Francisco Massa 2018-04-29 18:10:03 +02:00 committed by GitHub
parent e6ce1afe47
commit b240cc9b87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 25 deletions

View file

@ -5,13 +5,13 @@ from torch.utils.cpp_extension import CUDA_HOME
ext_modules = [
CppExtension(
'torch_test_cpp_extension', ['extension.cpp'],
'torch_test_cpp_extension.cpp', ['extension.cpp'],
extra_compile_args=['-g']),
]
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension(
'torch_test_cuda_extension', [
'torch_test_cpp_extension.cuda', [
'cuda_extension.cpp',
'cuda_extension_kernel.cu',
'cuda_extension_kernel2.cu',
@ -22,5 +22,6 @@ if torch.cuda.is_available() and CUDA_HOME is not None:
setup(
name='torch_test_cpp_extension',
packages=['torch_test_cpp_extension'],
ext_modules=ext_modules,
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension})

View file

@ -94,27 +94,15 @@ def test_cpp_extensions(python, test_module, test_directory, options):
python_path = os.environ.get('PYTHONPATH', '')
try:
cpp_extensions = os.path.join(test_directory, 'cpp_extensions')
if sys.platform == 'win32':
install_directory = os.path.join(cpp_extensions, 'install')
install_directories = get_shell_output(
'where -r "{}" *.pyd'.format(install_directory)).split('\r\n')
assert install_directories, 'install_directory must not be empty'
if len(install_directories) >= 1:
install_directory = install_directories[0]
install_directory = os.path.dirname(install_directory)
split_char = ';'
else:
install_directory = get_shell_output(
"find {}/install -name *-packages".format(cpp_extensions))
split_char = ':'
install_directory = ''
# install directory is the one that is named site-packages
for root, directories, _ in os.walk(os.path.join(cpp_extensions, 'install')):
for directory in directories:
if '-packages' in directory:
install_directory = os.path.join(root, directory)
assert install_directory, 'install_directory must not be empty'
install_directory = os.path.join(test_directory, install_directory)
os.environ['PYTHONPATH'] = '{}{}{}'.format(install_directory,
split_char, python_path)
os.environ['PYTHONPATH'] = os.pathsep.join([install_directory, python_path])
return run_test(python, test_module, test_directory, options)
finally:
os.environ['PYTHONPATH'] = python_path

View file

@ -3,8 +3,8 @@ import unittest
import torch
import torch.utils.cpp_extension
try:
import torch_test_cpp_extension as cpp_extension
except ModuleNotFoundError:
import torch_test_cpp_extension.cpp as cpp_extension
except ImportError:
print("\'test_cpp_extensions.py\' cannot be invoked directly. " +
"Run \'python run_test.py -i cpp_extensions\' for the \'test_cpp_extensions.py\' tests.")
raise
@ -70,7 +70,7 @@ class TestCppExtension(common.TestCase):
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
def test_cuda_extension(self):
import torch_test_cuda_extension as cuda_extension
import torch_test_cpp_extension.cuda as cuda_extension
x = torch.FloatTensor(100).zero_().cuda()
y = torch.FloatTensor(100).zero_().cuda()

View file

@ -245,7 +245,13 @@ class BuildExtension(build_ext):
check_compiler_abi_compatibility(compiler)
def _define_torch_extension_name(self, extension):
define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
# pybind11 doesn't support dots in the names
# so in order to support extensions in the packages
# like torch._C, we take the last part of the string
# as the library name
names = extension.name.split('.')
name = names[-1]
define = '-DTORCH_EXTENSION_NAME={}'.format(name)
if isinstance(extension.extra_compile_args, dict):
for args in extension.extra_compile_args.values():
args.append(define)