mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Test TORCH_LIBRARY in CUDA extension (#47524)
Summary: In the [official documentation](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html), it is recommended to use `TORCH_LIBRARY` to register ops for TorchScript. However, that code is never tested with CUDA extension and is actually broken (https://github.com/pytorch/pytorch/issues/47493). This PR adds a test for it. It will not pass CI now, but it will pass when the issue https://github.com/pytorch/pytorch/issues/47493 is fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/47524 Reviewed By: zou3519 Differential Revision: D24991839 Pulled By: ezyang fbshipit-source-id: 037196621c7ff9a6e7905efc1097ff97906a0b1c
This commit is contained in:
parent
cf92b0f3a0
commit
b12d645c2f
3 changed files with 58 additions and 0 deletions
|
|
@ -4,6 +4,7 @@ import os
|
|||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
|
||||
if sys.platform == 'win32':
|
||||
vc_version = os.getenv('VCToolsVersion', '')
|
||||
|
|
@ -55,6 +56,31 @@ elif torch.cuda.is_available() and ROCM_HOME is not None:
|
|||
])
|
||||
ext_modules.append(extension)
|
||||
|
||||
if not IS_WINDOWS: # MSVC has bug compiling this example
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'torch_library.cu'
|
||||
],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
elif torch.cuda.is_available() and ROCM_HOME is not None:
|
||||
from torch.utils.hipify import hipify_python
|
||||
hipify_python.hipify(
|
||||
project_directory=this_dir,
|
||||
output_directory=this_dir,
|
||||
includes="./*",
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,)
|
||||
extension = CUDAExtension(
|
||||
'torch_test_cpp_extension.torch_library', [
|
||||
'hip/torch_library.hip'
|
||||
],
|
||||
extra_compile_args={'cxx': CXX_FLAGS,
|
||||
'nvcc': ['-O2']})
|
||||
ext_modules.append(extension)
|
||||
|
||||
setup(
|
||||
name='torch_test_cpp_extension',
|
||||
packages=['torch_test_cpp_extension'],
|
||||
|
|
|
|||
9
test/cpp_extensions/torch_library.cu
Normal file
9
test/cpp_extensions/torch_library.cu
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
#include <torch/extension.h>
|
||||
|
||||
bool logical_and(bool a, bool b) { return a && b; }
|
||||
|
||||
TORCH_LIBRARY(torch_library, m) {
|
||||
m.def("logical_and", &logical_and);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
|
||||
|
|
@ -171,5 +171,28 @@ class TestRNGExtension(common.TestCase):
|
|||
del copy2
|
||||
self.assertEqual(rng_extension.getInstanceCount(), 0)
|
||||
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
|
||||
@unittest.skipIf(IS_WINDOWS, "MSVC have bug compiling this")
|
||||
class TestTorchLibrary(common.TestCase):
|
||||
|
||||
def test_torch_library(self):
|
||||
import torch_test_cpp_extension.torch_library # noqa: F401
|
||||
|
||||
def f(a: bool, b: bool):
|
||||
return torch.ops.torch_library.logical_and(a, b)
|
||||
|
||||
self.assertTrue(f(True, True))
|
||||
self.assertFalse(f(True, False))
|
||||
self.assertFalse(f(False, True))
|
||||
self.assertFalse(f(False, False))
|
||||
s = torch.jit.script(f)
|
||||
self.assertTrue(s(True, True))
|
||||
self.assertFalse(s(True, False))
|
||||
self.assertFalse(s(False, True))
|
||||
self.assertFalse(s(False, False))
|
||||
self.assertIn('torch_library::logical_and', str(s.graph))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common.run_tests()
|
||||
|
|
|
|||
Loading…
Reference in a new issue