mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add support for dotted names in CPP Extensions (#6986)
* Add support for dotted names in CPP Extensions * Modify tests for cpp extensions Test that dotted names work * Py2 fixes * Make run_test cpp_extensions Win-compatible
This commit is contained in:
parent
e6ce1afe47
commit
b240cc9b87
5 changed files with 20 additions and 25 deletions
|
|
@ -5,13 +5,13 @@ from torch.utils.cpp_extension import CUDA_HOME
|
|||
|
||||
ext_modules = [
|
||||
CppExtension(
|
||||
'torch_test_cpp_extension', ['extension.cpp'],
|
||||
'torch_test_cpp_extension.cpp', ['extension.cpp'],
|
||||
extra_compile_args=['-g']),
|
||||
]
|
||||
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cuda_extension', [
|
||||
'torch_test_cpp_extension.cuda', [
|
||||
'cuda_extension.cpp',
|
||||
'cuda_extension_kernel.cu',
|
||||
'cuda_extension_kernel2.cu',
|
||||
|
|
@ -22,5 +22,6 @@ if torch.cuda.is_available() and CUDA_HOME is not None:
|
|||
|
||||
setup(
|
||||
name='torch_test_cpp_extension',
|
||||
packages=['torch_test_cpp_extension'],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension})
|
||||
|
|
|
|||
0
test/cpp_extensions/torch_test_cpp_extension/__init__.py
Normal file
0
test/cpp_extensions/torch_test_cpp_extension/__init__.py
Normal file
|
|
@ -94,27 +94,15 @@ def test_cpp_extensions(python, test_module, test_directory, options):
|
|||
python_path = os.environ.get('PYTHONPATH', '')
|
||||
try:
|
||||
cpp_extensions = os.path.join(test_directory, 'cpp_extensions')
|
||||
if sys.platform == 'win32':
|
||||
install_directory = os.path.join(cpp_extensions, 'install')
|
||||
install_directories = get_shell_output(
|
||||
'where -r "{}" *.pyd'.format(install_directory)).split('\r\n')
|
||||
|
||||
assert install_directories, 'install_directory must not be empty'
|
||||
|
||||
if len(install_directories) >= 1:
|
||||
install_directory = install_directories[0]
|
||||
|
||||
install_directory = os.path.dirname(install_directory)
|
||||
split_char = ';'
|
||||
else:
|
||||
install_directory = get_shell_output(
|
||||
"find {}/install -name *-packages".format(cpp_extensions))
|
||||
split_char = ':'
|
||||
install_directory = ''
|
||||
# install directory is the one that is named site-packages
|
||||
for root, directories, _ in os.walk(os.path.join(cpp_extensions, 'install')):
|
||||
for directory in directories:
|
||||
if '-packages' in directory:
|
||||
install_directory = os.path.join(root, directory)
|
||||
|
||||
assert install_directory, 'install_directory must not be empty'
|
||||
install_directory = os.path.join(test_directory, install_directory)
|
||||
os.environ['PYTHONPATH'] = '{}{}{}'.format(install_directory,
|
||||
split_char, python_path)
|
||||
os.environ['PYTHONPATH'] = os.pathsep.join([install_directory, python_path])
|
||||
return run_test(python, test_module, test_directory, options)
|
||||
finally:
|
||||
os.environ['PYTHONPATH'] = python_path
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import unittest
|
|||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
try:
|
||||
import torch_test_cpp_extension as cpp_extension
|
||||
except ModuleNotFoundError:
|
||||
import torch_test_cpp_extension.cpp as cpp_extension
|
||||
except ImportError:
|
||||
print("\'test_cpp_extensions.py\' cannot be invoked directly. " +
|
||||
"Run \'python run_test.py -i cpp_extensions\' for the \'test_cpp_extensions.py\' tests.")
|
||||
raise
|
||||
|
|
@ -70,7 +70,7 @@ class TestCppExtension(common.TestCase):
|
|||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
def test_cuda_extension(self):
|
||||
import torch_test_cuda_extension as cuda_extension
|
||||
import torch_test_cpp_extension.cuda as cuda_extension
|
||||
|
||||
x = torch.FloatTensor(100).zero_().cuda()
|
||||
y = torch.FloatTensor(100).zero_().cuda()
|
||||
|
|
|
|||
|
|
@ -245,7 +245,13 @@ class BuildExtension(build_ext):
|
|||
check_compiler_abi_compatibility(compiler)
|
||||
|
||||
def _define_torch_extension_name(self, extension):
|
||||
define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
|
||||
# pybind11 doesn't support dots in the names
|
||||
# so in order to support extensions in the packages
|
||||
# like torch._C, we take the last part of the string
|
||||
# as the library name
|
||||
names = extension.name.split('.')
|
||||
name = names[-1]
|
||||
define = '-DTORCH_EXTENSION_NAME={}'.format(name)
|
||||
if isinstance(extension.extra_compile_args, dict):
|
||||
for args in extension.extra_compile_args.values():
|
||||
args.append(define)
|
||||
|
|
|
|||
Loading…
Reference in a new issue