From 9ee506bd938b260e3cbd1f6a1c04bd136768f954 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Thu, 6 Feb 2025 19:04:48 +0000 Subject: [PATCH] [CUDA][cuBLAS] Add fp16 accumulate option to cuBLAS/cuBLASLt (#144441) Test for `cublasGemmEx` added, still need to figure out the best way to exercise the other APIs... Pull Request resolved: https://github.com/pytorch/pytorch/pull/144441 Approved by: https://github.com/Chillee, https://github.com/malfet --- aten/src/ATen/Context.cpp | 9 +++- aten/src/ATen/Context.h | 3 ++ aten/src/ATen/cuda/CUDABlas.cpp | 89 +++++++++++++++++++++++++++------ docs/source/notes/cuda.rst | 25 +++++++++ test/test_cuda.py | 7 +++ test/test_matmul_cuda.py | 41 ++++++++++++++- torch/_C/__init__.pyi.in | 4 ++ torch/backends/cuda/__init__.py | 4 ++ torch/csrc/Module.cpp | 31 ++++++++++++ 9 files changed, 196 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index db5380ac961..4f269f8ae71 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -394,7 +394,6 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { rocm_fa_preferred_backend = b; } - bool Context::allowFP16ReductionCuBLAS() const { return allow_fp16_reduction_cublas; } @@ -411,6 +410,14 @@ void Context::setAllowBF16ReductionCuBLAS(bool b) { allow_bf16_reduction_cublas = b; } +bool Context::allowFP16AccumulationCuBLAS() const { + return allow_fp16_accumulation_cublas; +} + +void Context::setAllowFP16AccumulationCuBLAS(bool b) { + allow_fp16_accumulation_cublas = b; +} + bool Context::hasMKL() { #if AT_MKL_ENABLED() diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index d7e5079e264..b03615279ae 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -341,6 +341,8 @@ class TORCH_API Context { void setAllowFP16ReductionCuBLAS(bool); bool allowBF16ReductionCuBLAS() const; void setAllowBF16ReductionCuBLAS(bool); + bool allowFP16AccumulationCuBLAS() const; + void 
setAllowFP16AccumulationCuBLAS(bool); at::QEngine qEngine() const; void setQEngine(at::QEngine e); static const std::vector<at::QEngine>& supportedQEngines(); @@ -418,6 +420,7 @@ class TORCH_API Context { bool allow_tf32_cudnn = true; bool allow_fp16_reduction_cublas = true; bool allow_bf16_reduction_cublas = true; + bool allow_fp16_accumulation_cublas = false; bool enabled_mkldnn = true; bool enabled_nnpack = true; at::LinalgBackend linalg_preferred_backend = diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 6f2e7a8315a..4e4f4aa2126 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -343,6 +343,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { cudaDataType_t abcType = CUDA_R_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cudaDataType_t scaleType = CUDA_R_32F; +#ifndef USE_ROCM + at::Half halpha; + at::Half hbeta; +#endif + void * alpha_ptr = &alpha; + void * beta_ptr = &beta; if constexpr (std::is_same_v<Dtype, double>) { abcType = CUDA_R_64F; computeType = CUBLAS_COMPUTE_64F; @@ -359,6 +365,16 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { abcType = CUDA_C_32F; scaleType = CUDA_C_32F; } else if constexpr (std::is_same_v<Dtype, at::Half>) { +#ifndef USE_ROCM + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) { + computeType = CUBLAS_COMPUTE_16F; + halpha = alpha; + hbeta = beta; + alpha_ptr = &halpha; + beta_ptr = &hbeta; + } +#endif abcType = CUDA_R_16F; } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) { abcType = CUDA_R_16BF; @@ -437,12 +453,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { cublasStatus_t cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), - &alpha, + alpha_ptr, a, Adesc.descriptor(), b, Bdesc.descriptor(), - &beta, + beta_ptr, c, Cdesc.descriptor(), c, @@ -552,6 +568,13 @@ void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
BGEMM_CHECK_ARGVALUES(at::Half); float falpha = alpha; float fbeta = beta; +#ifndef USE_ROCM + at::Half halpha; + at::Half hbeta; + auto compute_type = CUDA_R_32F; +#endif + void * alpha_ptr = &falpha; + void * beta_ptr = &fbeta; #ifdef USE_ROCM int flag = 0; #if USE_GEMM_FLAGS_FP16_ALT_IMPL @@ -560,21 +583,28 @@ void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k, - (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea, + (void*)alpha_ptr, a, rocblas_datatype_f16_r, (int)lda, stridea, b, rocblas_datatype_f16_r, (int)ldb, strideb, - (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec, + (void*)beta_ptr, c, rocblas_datatype_f16_r, (int)ldc, stridec, c, rocblas_datatype_f16_r, (int)ldc, stridec, (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, flag))); #else cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) { + halpha = alpha; + hbeta = beta; + compute_type = CUDA_R_16F; + alpha_ptr = &halpha; + beta_ptr = &hbeta; + } if (prop->major >= 5){ TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx( handle, opa, opb, m, n, k, - (void*)(&falpha), a, CUDA_R_16F, lda, stridea, - b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta), + alpha_ptr, a, CUDA_R_16F, lda, stridea, + b, CUDA_R_16F, ldb, strideb, beta_ptr, c, CUDA_R_16F, ldc, stridec, - num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { for (const auto i : c10::irange(num_batches)) { at::cuda::blas::gemm( @@ -889,6 +919,13 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::Half)) { cublasOperation_t opb = _cublasOpFromChar(transb); float falpha = alpha; float fbeta = beta; +#ifndef USE_ROCM + at::Half halpha; + at::Half 
hbeta; + auto compute_type = CUDA_R_32F; +#endif + void * alpha_ptr = &falpha; + void * beta_ptr = &fbeta; _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); GEMM_CHECK_ARGVALUES(at::Half); #ifdef USE_ROCM @@ -903,14 +940,14 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::Half)) { m, n, k, - &falpha, + alpha_ptr, a, rocblas_datatype_f16_r, lda, b, rocblas_datatype_f16_r, ldb, - &fbeta, + beta_ptr, c, rocblas_datatype_f16_r, ldc, @@ -923,6 +960,13 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::Half)) { flag))); #else cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) { + compute_type = CUDA_R_16F; + halpha = alpha; + hbeta = beta; + alpha_ptr = &halpha; + beta_ptr = &hbeta; + } if (prop->major >= 5) { cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH; if (!at::globalContext().allowFP16ReductionCuBLAS()) { @@ -937,18 +981,18 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::Half)) { m, n, k, - &falpha, + alpha_ptr, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, - &fbeta, + beta_ptr, c, CUDA_R_16F, ldc, - CUDA_R_32F, + compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } else { @@ -1250,6 +1294,12 @@ void gemm_and_bias( cudaDataType_t abcType = CUDA_R_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cudaDataType_t scaleType = CUDA_R_32F; + void * alpha_ptr = &alpha_val; + void * beta_ptr = &beta_val; +#ifndef USE_ROCM + at::Half halpha_val; + at::Half hbeta_val; +#endif if constexpr (std::is_same_v) { abcType = CUDA_R_64F; computeType = CUBLAS_COMPUTE_64F; @@ -1260,6 +1310,17 @@ void gemm_and_bias( } abcType = CUDA_R_32F; } else if constexpr (std::is_same_v) { +#ifndef USE_ROCM + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) { + computeType = CUBLAS_COMPUTE_16F; + scaleType = 
CUDA_R_16F; + halpha_val = alpha_val; + hbeta_val = beta_val; + alpha_ptr = &halpha_val; + beta_ptr = &hbeta_val; + } +#endif abcType = CUDA_R_16F; } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) { abcType = CUDA_R_16BF; @@ -1342,12 +1403,12 @@ void gemm_and_bias( cublasStatus_t cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), - &alpha_val, + alpha_ptr, mat1_ptr, Adesc.descriptor(), mat2_ptr, Bdesc.descriptor(), - &beta_val, + beta_ptr, result_ptr, Cdesc.descriptor(), result_ptr, diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 74d0c89387f..e5c13dd924c 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -148,6 +148,9 @@ For more information about TF32, see: Reduced Precision Reduction in FP16 GEMMs ----------------------------------------- +(Distinct from full FP16 accumulation that is intended for hardware that has higher throughput +with FP16 accumulation than FP32 accumulation, see :ref:`Full FP16 accumulation<fp16accumulation>`) + fp16 GEMMs are potentially done with some intermediate reduced precision reductions (e.g., in fp16 rather than fp32). These selective reductions in precision can allow for higher performance on certain workloads (particularly those with a large `k` dimension) and GPU architectures at the cost of numerical precision and potential for overflow. Some example benchmark data on V100: @@ -206,6 +209,28 @@ To toggle the reduced precision reduction flags in C++, one can do at::globalContext().setAllowBF16ReductionCuBLAS(true); +.. _fp16accumulation: + +Full FP16 Accumulation in FP16 GEMMs +------------------------------------- + +Certain GPUs have increased performance when doing _all_ FP16 GEMM accumulation +in FP16, at the cost of numerical precision and greater likelihood of overflow. +Note that this setting only has an effect on GPUs of compute capability 7.0 (Volta) +or newer. + +This behavior can be enabled via: + +..
code:: python + + torch.backends.cuda.matmul.allow_fp16_accumulation = True + +To toggle the FP16 accumulation flag in C++, one can do + +.. code:: C++ + + at::globalContext().setAllowFP16AccumulationCuBLAS(true); + Asynchronous execution ---------------------- diff --git a/test/test_cuda.py b/test/test_cuda.py index 3d28ed24263..04de04ec2e4 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -672,6 +672,13 @@ class TestCuda(TestCase): ) torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig + def test_cublas_allow_fp16_accumulation_get_set(self): + orig = torch.backends.cuda.matmul.allow_fp16_accumulation + self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), orig) + torch.backends.cuda.matmul.allow_fp16_accumulation = not orig + self.assertEqual(torch._C._get_cublas_allow_fp16_accumulation(), not orig) + torch.backends.cuda.matmul.allow_fp16_accumulation = orig + def test_cudnn_allow_tf32_get_set(self): with torch.backends.cudnn.flags( enabled=None, benchmark=None, deterministic=None, allow_tf32=False diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index e0eaa52f093..940b9a98357 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -35,10 +35,10 @@ from torch.testing._internal.common_utils import ( IS_WINDOWS, parametrize, run_tests, + skipIfRocm, skipIfRocmVersionLessThan, TEST_CUDA, TEST_WITH_ROCM, - skipIfRocm, TestCase, ) @@ -60,7 +60,7 @@ class TestMatmulCuda(TestCase): torch.backends.cuda.matmul.allow_tf32 = True super(self.__class__, self).tearDown() - def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False): + def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False): # # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between # results from the CUDA invocation of torch.addmm and the CPU invocation @@ -72,8 +72,10 @@ class TestMatmulCuda(TestCase): # which fail the
threshold check orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision + torch.backends.cuda.matmul.allow_fp16_accumulation = fp16_accumulate # Make random tensors on CPU (seed set on common_utils.py import) # (Not using numpy because it does not support bfloat16) make_arg = partial(make_tensor, dtype=dtype, device="cpu") @@ -81,6 +83,10 @@ class TestMatmulCuda(TestCase): m_input = make_arg((n, p)) m_1 = make_arg((n, m)) m_2 = make_arg((m, p)) + # scale to abate overflows in fp16 accum + if fp16_accumulate: + m_1 = m_1 / 100 + m_2 = m_2 / 100 # *(B)FLOAT16 Special Handling* # Backend does not tensorize float16 on CPU, # and bloat16 may present accuracy issues, @@ -114,6 +120,7 @@ class TestMatmulCuda(TestCase): self.assertEqual(res_cpu, res_cuda) torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate @onlyCUDA @skipIfRocmVersionLessThan((5, 2)) @@ -136,6 +143,36 @@ class TestMatmulCuda(TestCase): def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype): self.cublas_addmm(size, dtype, True) + @onlyCUDA + @skipIfRocmVersionLessThan((5, 2)) + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1), + torch.bfloat16: xtol(atol=1e1, rtol=2e-1)}) + @dtypes(torch.float16, torch.bfloat16) + @parametrize("size", [100, 1000, 10000]) + def test_cublas_addmm_reduced_precision_fp16_accumulate(self, size: int, dtype: torch.dtype): + self.cublas_addmm(size, dtype, False, True) + 
+ @onlyCUDA + @skipIfRocm + def test_cublas_and_lt_reduced_precision_fp16_accumulate(self): + orig_fp16_accumulate = torch.backends.cuda.matmul.allow_fp16_accumulation + torch.backends.cuda.matmul.allow_fp16_accumulation = True + x = torch.rand(32, 512, 512, device='cuda', dtype=torch.half) + w = torch.rand(512, 512, device='cuda', dtype=torch.half) + b = torch.rand(512, device='cuda', dtype=torch.half) + out = torch.nn.functional.linear(x, w, b) + out_cpu = torch.nn.functional.linear(x.cpu(), w.cpu(), b.cpu()) + self.assertEqual(out, out_cpu, atol=5e-3, rtol=8e-3) + + a = torch.rand(16, 128, 128, device='cuda', dtype=torch.half) + b = torch.rand(16, 128, 128, device='cuda', dtype=torch.half) + c = torch.rand(16, 128, 128, device='cuda', dtype=torch.half) + out = torch.baddbmm(a, b, c) + out_cpu = torch.baddbmm(a.cpu(), b.cpu(), c.cpu()) + self.assertEqual(out, out_cpu, atol=1e-3, rtol=5e-3) + torch.backends.cuda.matmul.allow_fp16_accumulation = orig_fp16_accumulate + @onlyCUDA @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)}) @dtypes(torch.float16) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index eed4641d64f..1f0b0ac33fb 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1210,6 +1210,10 @@ def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... # THPMod def _set_cublas_allow_bf16_reduced_precision_reduction( arg: _bool, ) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS +def _get_cublas_allow_fp16_accumulation() -> _bool: ... # THPModule_allowFP16AccumulationCuBLAS +def _set_cublas_allow_fp16_accumulation( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16AccumulationCuBLAS def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... 
diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index b305819c1b0..0e4ed644010 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -133,6 +133,8 @@ class cuBLASModule: return torch._C._get_cublas_allow_fp16_reduced_precision_reduction() elif name == "allow_bf16_reduced_precision_reduction": return torch._C._get_cublas_allow_bf16_reduced_precision_reduction() + elif name == "allow_fp16_accumulation": + return torch._C._get_cublas_allow_fp16_accumulation() raise AttributeError("Unknown attribute " + name) def __setattr__(self, name, value): @@ -142,6 +144,8 @@ class cuBLASModule: return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value) elif name == "allow_bf16_reduced_precision_reduction": return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value) + elif name == "allow_fp16_accumulation": + return torch._C._set_cublas_allow_fp16_accumulation(value) raise AttributeError("Unknown attribute " + name) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 2230b15aeb3..59c2e96be54 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1133,6 +1133,29 @@ static PyObject* THPModule_allowBF16ReductionCuBLAS( Py_RETURN_FALSE; } +static PyObject* THPModule_setAllowFP16AccumulationCuBLAS( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + PyBool_Check(arg), + "set_allow_fp16_accumulation_cublas expects a bool, " + "but got ", + THPUtils_typename(arg)); + at::globalContext().setAllowFP16AccumulationCuBLAS(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THPModule_allowFP16AccumulationCuBLAS( + PyObject* _unused, + PyObject* noargs) { + if (at::globalContext().allowFP16AccumulationCuBLAS()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + static PyObject* THPModule_setAllowFP16ReductionCPU( PyObject* _unused, PyObject* arg) { @@ -1574,6 +1597,14 @@ static std::initializer_list TorchMethods = { 
THPModule_setAllowBF16ReductionCuBLAS, METH_O, nullptr}, + {"_get_cublas_allow_fp16_accumulation", + THPModule_allowFP16AccumulationCuBLAS, + METH_NOARGS, + nullptr}, + {"_set_cublas_allow_fp16_accumulation", + THPModule_setAllowFP16AccumulationCuBLAS, + METH_O, + nullptr}, {"_get_cpu_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCPU, METH_NOARGS,