diff --git a/BUILD.bazel b/BUILD.bazel
index 6aacd560c40..a8ea7988a24 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -723,6 +723,7 @@ torch_cuda_half_options = [
     "-DCUDA_HAS_FP16=1",
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
+    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
     "-D__CUDA_NO_HALF2_OPERATORS__",
 ]
 
diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h
index da6ce385955..57e2a69b86f 100644
--- a/c10/util/BFloat16-inl.h
+++ b/c10/util/BFloat16-inl.h
@@ -7,15 +7,44 @@ namespace c10 {
 /// Constructors
 inline C10_HOST_DEVICE BFloat16::BFloat16(float value) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  x = __bfloat16_as_ushort(__float2bfloat16(value));
+#else
   // RNE by default
   x = detail::round_to_nearest_even(value);
+#endif
 }
 
 /// Implicit conversions
 inline C10_HOST_DEVICE BFloat16::operator float() const {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
+#else
   return detail::f32_from_bits(x);
+#endif
 }
 
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
+  x = *reinterpret_cast<const unsigned short*>(&value);
+}
+inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
+  return *reinterpret_cast<const __nv_bfloat16*>(&x);
+}
+#endif
+
+// CUDA intrinsics
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
+#else
+  return *ptr;
+#endif
+}
+#endif
+
 /// Arithmetic
 
 inline C10_HOST_DEVICE BFloat16
 operator+(const BFloat16& a, const BFloat16& b) {
diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h
index 375b1086e07..0bd115d568f 100644
--- a/c10/util/BFloat16.h
+++ b/c10/util/BFloat16.h
@@ -7,6 +7,10 @@
 #include <cmath>
 #include <cstring>
 
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+#include <cuda_bf16.h>
+#endif
+
 namespace c10 {
 
 namespace detail {
@@ -84,6 +88,11 @@ struct alignas(2) BFloat16 {
   constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) : x(bits){};
   inline C10_HOST_DEVICE BFloat16(float value);
   inline C10_HOST_DEVICE operator float() const;
+
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
+  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
+  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
+#endif
 };
 
 } // namespace c10
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 023bbe9e8d0..1bbb98fb361 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1504,7 +1504,8 @@ if(NOT INTERN_BUILD_MOBILE)
 
   if(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
     message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
-    list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
+    list(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1" "-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_CONVERSIONS__"
+        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "-D__CUDA_NO_HALF2_OPERATORS__")
     add_compile_options(-DCUDA_HAS_FP16=1)
   else()
     message(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 49fb7988c0d..feecc39acd8 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -152,6 +152,7 @@ MSVC_IGNORE_CUDAFE_WARNINGS = [
 COMMON_NVCC_FLAGS = [
     '-D__CUDA_NO_HALF_OPERATORS__',
     '-D__CUDA_NO_HALF_CONVERSIONS__',
+    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
     '-D__CUDA_NO_HALF2_OPERATORS__',
     '--expt-relaxed-constexpr'
 ]
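
Note (not part of the patch): the sketch below is a hypothetical illustration of how the conversions added above could be exercised from a CUDA kernel. The kernel name scale_bf16 and its launch setup are made up for the example; the fast path assumes CUDA 11+ and an sm_80+ device, and uses the cuda_bf16.h intrinsics __hmul and __float2bfloat16 together with the new explicit operator __nv_bfloat16() and the __nv_bfloat16 converting constructor.

// Hypothetical usage sketch: exercises the new c10::BFloat16 <-> __nv_bfloat16
// conversions and the __ldg(const BFloat16*) overload added by this patch.
#include <c10/util/BFloat16.h>  // pulls in <cuda_bf16.h> when CUDA_VERSION >= 11000

__global__ void scale_bf16(const c10::BFloat16* in, c10::BFloat16* out,
                           float scale, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && \
    defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // sm_80+: the explicit operator __nv_bfloat16() hands the value to the
    // native bf16 intrinsics; the converting constructor stores the result.
    __nv_bfloat16 v = static_cast<__nv_bfloat16>(__ldg(in + i));
    out[i] = c10::BFloat16(__hmul(v, __float2bfloat16(scale)));
#else
    // Older toolkits/architectures: round-trip through float, as before.
    out[i] = static_cast<float>(in[i]) * scale;
#endif
  }
}

The operator __nv_bfloat16() is declared explicit, and -D__CUDA_NO_BFLOAT16_CONVERSIONS__ is added to the NVCC flags, presumably so that the implicit float conversions defined on __nv_bfloat16 in cuda_bf16.h do not compete or create ambiguity with c10::BFloat16's own float constructor and conversion operator.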