[ROCm] set hipblas workspace (#138791)

Fixes #138532.

This brings hipblas behavior in line with cublas behavior with respect to setting the workspace to an allocation from the caching allocator as well as the env var HIPBLAS_WORKSPACE_CONFIG.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138791
Approved by: https://github.com/naromero77amd, https://github.com/eqy, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Jeff Daily 2024-10-29 01:37:52 +00:00 committed by PyTorch MergeBot
parent 07b0d633b8
commit 7c7b2d89ba
4 changed files with 86 additions and 17 deletions

View file

@ -48,6 +48,39 @@ void destroyCublasLtHandle(cublasLtHandle_t handle) {
}
using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
// ugly hack until hipblasSetWorkspace exists
#include <rocblas/rocblas.h>
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
switch(error) {
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
return HIPBLAS_STATUS_INTERNAL_ERROR;
}
TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, void* addr, size_t size) {
return rocBLASStatusToHIPStatus(rocblas_set_workspace((rocblas_handle)handle, addr, size));
}
// hipify mappings file correctly maps this but the function doesn't exist yet
#define hipblasSetWorkspace hipblasSetWorkspace_replacement
#endif
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
@ -77,17 +110,29 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
void clearCublasWorkspaces() {
#if !defined(USE_ROCM)
cublas_handle_stream_to_workspace().clear();
#endif
cublas_handle_stream_to_workspace().clear();
}
size_t parseChosenWorkspaceSize() {
const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
#ifdef USE_ROCM
if (!val) {
val = getenv("HIPBLAS_WORKSPACE_CONFIG");
}
if (!val) {
// for extra convenience
val = getenv("ROCBLAS_WORKSPACE_CONFIG");
}
/* 32MiB default, 128MiB for MI300 */
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
const bool gfx94 = properties != nullptr && properties->major == 9 && properties->minor == 4;
const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
#else
/* :4096:2:16:8 default, 32MiB for Hopper */
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;
#endif
if (val) {
size_t total_size = 0;
@ -156,7 +201,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
auto handle = myPoolWindow->reserve(device);
auto stream = c10::cuda::getCurrentCUDAStream();
TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if !defined(USE_ROCM)
// We explicitly set the cublas workspace even though CUDA 12.2+ fixed the
// issue where memory usage increased during graph capture.
// original issue: https://github.com/pytorch/pytorch/pull/83461
@ -171,6 +215,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
// To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.

View file

@ -103,7 +103,24 @@ complete snapshot of the memory allocator state via
underlying allocation patterns produced by your code.
To debug memory errors, set
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
``PYTORCH_NO_HIP_MEMORY_CACHING=1`` in your environment to disable caching.
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` is also accepted for ease of porting.
.. hipblas-workspaces:
hipBLAS workspaces
------------------
For each combination of hipBLAS handle and HIP stream, a hipBLAS workspace will be allocated if that
handle and stream combination executes a hipBLAS kernel that requires a workspace. In order to
avoid repeatedly allocating workspaces, these workspaces are not deallocated unless
``torch._C._cuda_clearCublasWorkspaces()`` is called; note that it's the same function for CUDA or
HIP. The workspace size per allocation can be specified via the environment variable
``HIPBLAS_WORKSPACE_CONFIG`` with the format ``:[SIZE]:[COUNT]``. As an example, the environment
variable ``HIPBLAS_WORKSPACE_CONFIG=:4096:2:16:8`` specifies a total size of ``2 * 4096 + 8 * 16
KiB`` or 8 MIB. The default workspace size is 32 MiB; MI300 and newer defaults to 128 MiB. To force
hipBLAS to avoid using workspaces, set ``HIPBLAS_WORKSPACE_CONFIG=:0:0``. For convenience,
``CUBLAS_WORKSPACE_CONFIG`` is also accepted.
.. _hipfft-plan-cache:

View file

@ -31,7 +31,6 @@ from torch.cuda._memory_viz import (
from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
from torch.testing._internal.common_cuda import (
_create_scaling_case,
_get_torch_cuda_version,
TEST_CUDNN,
TEST_MULTIGPU,
)
@ -63,6 +62,7 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
serialTest,
setBlasBackendsToDefaultFinally,
skipCUDAMemoryLeakCheckIf,
skipCUDANonDefaultStreamIf,
skipIfRocm,
@ -417,19 +417,23 @@ class TestCuda(TestCase):
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
@unittest.skipIf(
TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
)
@unittest.skipIf(
_get_torch_cuda_version() >= (12, 2),
"skipped as explicit workspace allocation is removed",
)
@unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
@setBlasBackendsToDefaultFinally
def test_cublas_workspace_explicit_allocation(self):
torch.backends.cuda.preferred_blas_library("cublas")
a = torch.randn(7, 7, device="cuda", requires_grad=False)
default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024 # :4096:2:16:8
# different size (32 MiB) expected on Hopper GPU
if torch.cuda.get_device_capability() == (9, 0):
default_workspace_size = 4096 * 8 * 1024
if torch.version.hip:
default_workspace_size = 1024 * 32 * 1024 # :1024:32 32MiB
# different size (128 MiB) expected on MI300 GPU
if torch.cuda.get_device_capability() >= (9, 4):
default_workspace_size = 1024 * 128 * 1024 # :1024:128
else:
default_workspace_size = (
4096 * 2 * 1024 + 16 * 8 * 1024
) # :4096:2:16:8 8MiB
# different size (32 MiB) expected on Hopper GPU
if torch.cuda.get_device_capability() == (9, 0):
default_workspace_size = 4096 * 8 * 1024
def check_workspace_size(inp):
torch._C._cuda_clearCublasWorkspaces()
@ -1919,7 +1923,9 @@ exit(2)
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
)
@serialTest()
@setBlasBackendsToDefaultFinally
def test_repeat_graph_capture_cublas_workspace_memory(self):
torch.backends.cuda.preferred_blas_library("cublas")
(x, y, z) = 1024, 512, 64
a = torch.rand((x, y), device="cuda")
b = torch.rand((y, z), device="cuda")

View file

@ -6693,6 +6693,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
"cublasGetVersion_v2",
("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
),
("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)),
("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)),
("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)),
("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)),