diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 874ddca27b9..3b156b6a736 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #define CUDABLAS_POSINT_CHECK(FD, X) \ TORCH_CHECK( \ @@ -96,7 +97,7 @@ namespace at { namespace cuda { namespace blas { -const char* _cublasGetErrorEnum(cublasStatus_t error) { +C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error) { if (error == CUBLAS_STATUS_SUCCESS) { return "CUBLAS_STATUS_SUCCESS"; } diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp index 48031c72c2f..683f50ea229 100644 --- a/aten/src/ATen/cuda/CUDASolver.cpp +++ b/aten/src/ATen/cuda/CUDASolver.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #ifdef CUDART_VERSION @@ -9,7 +10,7 @@ namespace at { namespace cuda { namespace solver { -const char* cusolverGetErrorMessage(cusolverStatus_t status) { +C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) { switch (status) { case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES"; case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED"; diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index a446bd6dbc7..2d1fd05fa2e 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -2,6 +2,7 @@ #include #include +#include #ifdef CUDART_VERSION #include @@ -39,7 +40,7 @@ class CuDNNError : public c10::Error { } while (0) namespace at { namespace cuda { namespace blas { -const char* _cublasGetErrorEnum(cublasStatus_t error); +C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error); }}} // namespace at::cuda::blas #define TORCH_CUDABLAS_CHECK(EXPR) \ @@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status); #ifdef CUDART_VERSION namespace at { namespace cuda { namespace solver { -const char* cusolverGetErrorMessage(cusolverStatus_t status); +C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status); }}} // namespace at::cuda::solver #define TORCH_CUSOLVER_CHECK(EXPR) \ diff --git a/test/cpp_extensions/cublas_extension.cpp b/test/cpp_extensions/cublas_extension.cpp new file mode 100644 index 00000000000..61945b1aa22 --- /dev/null +++ b/test/cpp_extensions/cublas_extension.cpp @@ -0,0 +1,17 @@ +#include + +#include +#include + +#include + +torch::Tensor noop_cublas_function(torch::Tensor x) { + cublasHandle_t handle; + TORCH_CUDABLAS_CHECK(cublasCreate(&handle)); + TORCH_CUDABLAS_CHECK(cublasDestroy(handle)); + return x; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("noop_cublas_function", &noop_cublas_function, "a cublas function"); +} diff --git a/test/cpp_extensions/cusolver_extension.cpp b/test/cpp_extensions/cusolver_extension.cpp new file mode 100644 index 00000000000..515d09958a8 --- /dev/null +++ b/test/cpp_extensions/cusolver_extension.cpp @@ -0,0 +1,17 @@ +#include +#include + +#include + + +torch::Tensor noop_cusolver_function(torch::Tensor x) { + cusolverDnHandle_t handle; + TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle)); + TORCH_CUSOLVER_CHECK(cusolverDnDestroy(handle)); + return x; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("noop_cusolver_function", &noop_cusolver_function, "a cusolver function"); +} diff --git a/test/cpp_extensions/setup.py b/test/cpp_extensions/setup.py index 7888d0e3a88..3b25f1e60bb 100644 --- a/test/cpp_extensions/setup.py +++ b/test/cpp_extensions/setup.py @@ -4,6 +4,7 @@ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME +from torch.testing._internal.common_utils import IS_WINDOWS if sys.platform == 'win32': vc_version = os.getenv('VCToolsVersion', '') @@ -48,6 +49,20 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None 'nvcc': ['-O2']}) ext_modules.append(extension) +# todo(mkozuki): Figure out the root cause +if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None: + cublas_extension = CUDAExtension( + name='torch_test_cpp_extension.cublas_extension', + sources=['cublas_extension.cpp'] + ) + ext_modules.append(cublas_extension) + + cusolver_extension = CUDAExtension( + name='torch_test_cpp_extension.cusolver_extension', + sources=['cusolver_extension.cpp'] + ) + ext_modules.append(cusolver_extension) + setup( name='torch_test_cpp_extension', packages=['torch_test_cpp_extension'], diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index eaa4899a926..5ca6d34fa6b 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -82,6 +82,26 @@ class TestCppExtensionAOT(common.TestCase): # 2 * sigmoid(0) = 2 * 0.5 = 1 self.assertEqual(z, torch.ones_like(z)) + @common.skipIfRocm + @unittest.skipIf(common.IS_WINDOWS, "Windows not supported") + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + def test_cublas_extension(self): + from torch_test_cpp_extension import cublas_extension + + x = torch.zeros(100, device="cuda", dtype=torch.float32) + z = cublas_extension.noop_cublas_function(x) + self.assertEqual(z, x) + + @common.skipIfRocm + @unittest.skipIf(common.IS_WINDOWS, "Windows not supported") + @unittest.skipIf(not TEST_CUDA, "CUDA not found") + def test_cusolver_extension(self): + from torch_test_cpp_extension import cusolver_extension + + x = torch.zeros(100, device="cuda", dtype=torch.float32) + z = cusolver_extension.noop_cusolver_function(x) + self.assertEqual(z, x) + @unittest.skipIf(IS_WINDOWS, "Not available on Windows") def test_no_python_abi_suffix_sets_the_correct_library_name(self): # For this test, run_test.py will call `python setup.py install` in the