[ROCm] set hipblas workspace (#138791)
Fixes #138532. This brings hipblas behavior in line with cublas behavior with respect to setting the workspace to an allocation from the caching allocator, as well as the env var HIPBLAS_WORKSPACE_CONFIG.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138791
Approved by: https://github.com/naromero77amd, https://github.com/eqy, https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
commit 7c7b2d89ba (parent 07b0d633b8)
4 changed files with 86 additions and 17 deletions
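In practice the new knob is used like the existing CUDA one. A minimal usage sketch (the value shown is the standard `:4096:2:16:8` example config, not a recommendation):

```python
import os

# Set before importing torch so the first BLAS handle sees it.
os.environ["HIPBLAS_WORKSPACE_CONFIG"] = ":4096:2:16:8"  # 2 x 4096 KiB + 8 x 16 KiB

import torch

a = torch.randn(512, 512, device="cuda")  # "cuda" is the HIP device on ROCm builds
b = a @ a  # the first matmul on this (handle, stream) pair allocates the workspace
```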
```diff
@@ -48,6 +48,39 @@ void destroyCublasLtHandle(cublasLtHandle_t handle) {
 }
 
 using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
 
+// ugly hack until hipblasSetWorkspace exists
+#include <rocblas/rocblas.h>
+
+static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
+  switch(error) {
+    case rocblas_status_size_unchanged:
+    case rocblas_status_size_increased:
+    case rocblas_status_success:
+      return HIPBLAS_STATUS_SUCCESS;
+    case rocblas_status_invalid_handle:
+      return HIPBLAS_STATUS_NOT_INITIALIZED;
+    case rocblas_status_not_implemented:
+      return HIPBLAS_STATUS_NOT_SUPPORTED;
+    case rocblas_status_invalid_pointer:
+    case rocblas_status_invalid_size:
+    case rocblas_status_invalid_value:
+      return HIPBLAS_STATUS_INVALID_VALUE;
+    case rocblas_status_memory_error:
+      return HIPBLAS_STATUS_ALLOC_FAILED;
+    case rocblas_status_internal_error:
+      return HIPBLAS_STATUS_INTERNAL_ERROR;
+  }
+  TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM");
+}
+
+static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, void* addr, size_t size) {
+  return rocBLASStatusToHIPStatus(rocblas_set_workspace((rocblas_handle)handle, addr, size));
+}
+
+// hipify mappings file correctly maps this but the function doesn't exist yet
+#define hipblasSetWorkspace hipblasSetWorkspace_replacement
+
+#endif
 
 std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
```
```diff
@@ -77,17 +110,29 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
 } // namespace
 
 void clearCublasWorkspaces() {
-#if !defined(USE_ROCM)
-  cublas_handle_stream_to_workspace().clear();
-#endif
+  cublas_handle_stream_to_workspace().clear();
 }
 
 size_t parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
+#ifdef USE_ROCM
+  if (!val) {
+    val = getenv("HIPBLAS_WORKSPACE_CONFIG");
+  }
+  if (!val) {
+    // for extra convenience
+    val = getenv("ROCBLAS_WORKSPACE_CONFIG");
+  }
+  /* 32MiB default, 128MiB for MI300 */
+  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
+  const bool gfx94 = properties != nullptr && properties->major == 9 && properties->minor == 4;
+  const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
+#else
   /* :4096:2:16:8 default, 32MiB for Hopper */
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
   const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;
+#endif
 
   if (val) {
     size_t total_size = 0;
```
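The lookup order above (CUBLAS first, then HIPBLAS, then ROCBLAS on ROCm builds) and the per-architecture defaults, paraphrased as a sketch; the Python names here are illustrative only, the real logic is the C++ above:

```python
import os

def chosen_workspace_config(is_rocm):
    """Illustrative paraphrase of the env lookup in parseChosenWorkspaceSize()."""
    val = os.getenv("CUBLAS_WORKSPACE_CONFIG")  # honored on CUDA and ROCm alike
    if is_rocm:
        val = val or os.getenv("HIPBLAS_WORKSPACE_CONFIG")
        val = val or os.getenv("ROCBLAS_WORKSPACE_CONFIG")  # extra convenience
    return val

def default_workspace_bytes(is_rocm, is_gfx94=False, is_sm90=False):
    """Defaults when no config is set: 32 MiB on ROCm (128 MiB on MI300/gfx94),
    ~8 MiB on CUDA (32 MiB on Hopper/sm90)."""
    if is_rocm:
        return 1024 * 128 * 1024 if is_gfx94 else 1024 * 32 * 1024
    return 4096 * 8 * 1024 if is_sm90 else 4096 * 1024 * 2 + 16 * 1024 * 8
```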
```diff
@@ -156,7 +201,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if !defined(USE_ROCM)
   // We explicitly set the cublas workspace even though CUDA 12.2+ fixed the
   // issue where memory usage increased during graph capture.
   // original issue: https://github.com/pytorch/pytorch/pull/83461
```
```diff
@@ -171,6 +215,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
     workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
   }
   TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
+#if !defined(USE_ROCM)
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
   // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
```
```diff
@@ -103,7 +103,24 @@ complete snapshot of the memory allocator state via
 underlying allocation patterns produced by your code.
 
 To debug memory errors, set
-``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_HIP_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` is also accepted for ease of porting.
+
+.. _hipblas-workspaces:
+
+hipBLAS workspaces
+------------------
+
+For each combination of hipBLAS handle and HIP stream, a hipBLAS workspace will be allocated if that
+handle and stream combination executes a hipBLAS kernel that requires a workspace. In order to
+avoid repeatedly allocating workspaces, these workspaces are not deallocated unless
+``torch._C._cuda_clearCublasWorkspaces()`` is called; note that it's the same function for CUDA or
+HIP. The workspace size per allocation can be specified via the environment variable
+``HIPBLAS_WORKSPACE_CONFIG`` with the format ``:[SIZE]:[COUNT]``. As an example, the environment
+variable ``HIPBLAS_WORKSPACE_CONFIG=:4096:2:16:8`` specifies a total size of ``2 * 4096 + 8 * 16
+KiB``, or 8 MiB. The default workspace size is 32 MiB; MI300 and newer default to 128 MiB. To force
+hipBLAS to avoid using workspaces, set ``HIPBLAS_WORKSPACE_CONFIG=:0:0``. For convenience,
+``CUBLAS_WORKSPACE_CONFIG`` is also accepted.
+
 .. _hipfft-plan-cache:
```
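The `:[SIZE]:[COUNT]` arithmetic from the note above, as a small sketch; `workspace_config_bytes` is a hypothetical helper for illustration, not a PyTorch API (the real parsing happens in C++):

```python
def workspace_config_bytes(cfg):
    """Total bytes for a ':SIZE:COUNT[:SIZE:COUNT...]' string; SIZE is in KiB."""
    fields = [int(f) for f in cfg.split(":") if f]
    return sum(size * 1024 * count for size, count in zip(fields[::2], fields[1::2]))

# The example from the note: 2 chunks of 4096 KiB plus 8 chunks of 16 KiB.
assert workspace_config_bytes(":4096:2:16:8") == 8320 * 1024  # ~8 MiB
assert workspace_config_bytes(":0:0") == 0  # disables workspaces
```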
```diff
@@ -31,7 +31,6 @@ from torch.cuda._memory_viz import (
 from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
 from torch.testing._internal.common_cuda import (
     _create_scaling_case,
-    _get_torch_cuda_version,
     TEST_CUDNN,
     TEST_MULTIGPU,
 )
```
```diff
@@ -63,6 +62,7 @@ from torch.testing._internal.common_utils import (
     parametrize,
     run_tests,
     serialTest,
+    setBlasBackendsToDefaultFinally,
     skipCUDAMemoryLeakCheckIf,
     skipCUDANonDefaultStreamIf,
     skipIfRocm,
```
```diff
@@ -417,19 +417,23 @@ class TestCuda(TestCase):
         q_copy[1].fill_(10)
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
 
-    @unittest.skipIf(
-        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
-    )
-    @unittest.skipIf(
-        _get_torch_cuda_version() >= (12, 2),
-        "skipped as explicit workspace allocation is removed",
-    )
+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
+    @setBlasBackendsToDefaultFinally
     def test_cublas_workspace_explicit_allocation(self):
+        torch.backends.cuda.preferred_blas_library("cublas")
         a = torch.randn(7, 7, device="cuda", requires_grad=False)
-        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
-        # different size (32 MiB) expected on Hopper GPU
-        if torch.cuda.get_device_capability() == (9, 0):
-            default_workspace_size = 4096 * 8 * 1024
+        if torch.version.hip:
+            default_workspace_size = 1024 * 32 * 1024  # :1024:32 32MiB
+            # different size (128 MiB) expected on MI300 GPU
+            if torch.cuda.get_device_capability() >= (9, 4):
+                default_workspace_size = 1024 * 128 * 1024  # :1024:128
+        else:
+            default_workspace_size = (
+                4096 * 2 * 1024 + 16 * 8 * 1024
+            )  # :4096:2:16:8 8MiB
+            # different size (32 MiB) expected on Hopper GPU
+            if torch.cuda.get_device_capability() == (9, 0):
+                default_workspace_size = 4096 * 8 * 1024
 
         def check_workspace_size(inp):
             torch._C._cuda_clearCublasWorkspaces()
```
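Outside the test harness, the retained workspace can be observed through the caching allocator's counters, since the workspace now comes from the caching allocator on both backends. A rough, hypothetical sketch (exact numbers vary by device and config):

```python
import torch

torch._C._cuda_clearCublasWorkspaces()  # same call clears hipBLAS workspaces on ROCm
a = torch.randn(7, 7, device="cuda")
torch.cuda.synchronize()
before = torch.cuda.memory_allocated()
torch.matmul(a, a)  # first BLAS call on this (handle, stream) allocates the workspace
torch.cuda.synchronize()
grown = torch.cuda.memory_allocated() - before  # the discarded result is already freed
print(f"workspace bytes retained after matmul: {grown}")
```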
```diff
@@ -1919,7 +1923,9 @@ exit(2)
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
     @serialTest()
+    @setBlasBackendsToDefaultFinally
     def test_repeat_graph_capture_cublas_workspace_memory(self):
+        torch.backends.cuda.preferred_blas_library("cublas")
         (x, y, z) = 1024, 512, 64
         a = torch.rand((x, y), device="cuda")
         b = torch.rand((y, z), device="cuda")
```
```diff
@@ -6693,6 +6693,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
         "cublasGetVersion_v2",
         ("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
     ),
+    ("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)),
     ("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)),
     ("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)),
     ("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)),
```
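A quick, hypothetical spot-check that the new identifier mapping is visible to hipify (assumes a PyTorch checkout with this patch applied):

```python
from torch.utils.hipify.cuda_to_hip_mappings import CUDA_IDENTIFIER_MAP

# The value is a ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)-style tuple;
# the CONV_*/API_* enum members print as integers.
print(CUDA_IDENTIFIER_MAP["cublasSetWorkspace"])
```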