Revert "[cuBLAS][cuBLASLt] Unify cuBLASLt workspaces with cuBLAS workspaces (#145130)"

This reverts commit 5f0901e573.

Reverted https://github.com/pytorch/pytorch/pull/145130 on behalf of https://github.com/atalman due to Reverted internally ([comment](https://github.com/pytorch/pytorch/pull/145130#issuecomment-2644122846))
PyTorch MergeBot 2025-02-07 21:04:23 +00:00
parent 206ad9f4ad
commit 80a1696679
4 changed files with 16 additions and 73 deletions
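For orientation: the change being reverted made cuBLASLt reuse the workspace that cuBLAS already caches per (handle, stream) pair, while this revert returns to allocating a fresh workspace tensor on every cuBLASLt call. Below is a minimal sketch of the two strategies, written against the ATen calls visible in this diff; the helper names are illustrative only, not the actual functions being restored or removed.

#include <tuple>
#include <cublas_v2.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContextLight.h>
#include <c10/cuda/CUDAStream.h>

// Restored strategy: allocate a byte tensor per cuBLASLt call and hand its
// storage to cublasLtMatmul as the workspace. The caching allocator makes the
// allocation cheap, but the buffer is separate from the cuBLAS workspace.
static at::Tensor lt_workspace_per_call(size_t workspaceSize) {
  return at::empty(static_cast<int64_t>(workspaceSize),
                   at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
}

// Reverted strategy (#145130): look up the workspace that the current cuBLAS
// handle already registered for the current stream, so cuBLAS and cuBLASLt
// share one buffer. This relies on the map accessor that the revert removes
// from the lightweight context header below.
static void* lt_workspace_shared_with_cublas() {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
  auto key = std::make_tuple(static_cast<void*>(handle), static_cast<void*>(stream));
  auto& map = at::cuda::cublas_handle_stream_to_workspace();
  auto it = map.find(key);
  TORCH_CHECK(it != map.end(), "no cuBLAS workspace for this handle/stream");
  return it->second.mutable_get();
}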

aten/src/ATen/cuda/CUDABlas.cpp

@@ -3,7 +3,6 @@
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContextLight.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
@@ -215,16 +214,6 @@ static size_t _getWorkspaceSize() {
return workspace_size;
}
void* _getWorkspaceWithoutHandle() {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_CHECK(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
return workspace_it->second.mutable_get();
}
} // anonymous namespace
namespace at::cuda::blas {
@@ -406,13 +395,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
}
CuBlasLtMatmulPreference preference;
#ifdef USE_ROCM
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
#else
size_t workspaceSize = getChosenWorkspaceSize();
#endif
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
#ifndef USE_ROCM
@@ -424,14 +409,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif
#ifdef USE_ROCM
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
auto workspace_ptr = workspace.mutable_get();
TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
#else
auto workspace_ptr = _getWorkspaceWithoutHandle();
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@@ -464,7 +442,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
c,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace_ptr,
workspace.mutable_data_ptr(),
workspaceSize,
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
@@ -1350,14 +1328,9 @@ void gemm_and_bias(
CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
CuBlasLtMatmulPreference preference;
#ifdef USE_ROCM
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
// setting this to 1M.
size_t workspaceSize = _getWorkspaceSize();
#else
size_t workspaceSize = getChosenWorkspaceSize();
#endif
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
#ifndef USE_ROCM
@@ -1371,16 +1344,7 @@ void gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
auto stream = c10::cuda::getCurrentCUDAStream();
#ifdef USE_ROCM
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
auto workspace_ptr = workspace.mutable_get();
TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
#else
auto workspace_ptr = _getWorkspaceWithoutHandle();
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@@ -1414,9 +1378,9 @@ void gemm_and_bias(
result_ptr,
Cdesc.descriptor(),
&heuristicResult.algo,
workspace_ptr,
workspace.mutable_data_ptr(),
workspaceSize,
stream);
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1575,17 +1539,8 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
auto stream = c10::cuda::getCurrentCUDAStream();
size_t workspaceSize = _getWorkspaceSize();
#ifdef USE_ROCM
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
auto workspace_ptr = workspace.mutable_get();
TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
#else
auto workspace_ptr = _getWorkspaceWithoutHandle();
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
@@ -1669,9 +1624,9 @@ void scaled_gemm(
result_ptr,
Ddesc.descriptor(),
&heuristicResult.algo,
workspace_ptr,
workspace.mutable_data_ptr(),
workspaceSize,
stream);
at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1740,8 +1695,8 @@ void int8_gemm(
CuBlasLtMatmulPreference preference;
size_t workspaceSize = _getWorkspaceSize();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto workspace = allocator.allocate(workspaceSize);
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -1779,7 +1734,7 @@ void int8_gemm(
nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
workspace.mutable_get(),
workspace.mutable_data_ptr(),
#else
nullptr, // Non-zero workspace doesn't seem to work.
#endif
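All of the cuBLASLt call sites above follow the same pattern: one byte count is advertised to the heuristic via CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, and the same count is later passed to cublasLtMatmul together with the workspace pointer. A stripped-down sketch of that pattern against the raw cuBLASLt API (no ATen wrappers; the function name and the reduced error handling are illustrative only):

#include <cublasLt.h>

// Pick a matmul algorithm while telling the heuristic how much workspace we
// intend to provide. The caller must later pass a buffer of at least
// workspaceSize bytes (in this diff, the at::empty byte tensor) to
// cublasLtMatmul along with the same size.
static cublasStatus_t pick_algo_with_workspace(
    cublasLtHandle_t ltHandle,
    cublasLtMatmulDesc_t computeDesc,
    cublasLtMatrixLayout_t Adesc,
    cublasLtMatrixLayout_t Bdesc,
    cublasLtMatrixLayout_t Cdesc,
    size_t workspaceSize,
    cublasLtMatmulHeuristicResult_t* result) {
  cublasLtMatmulPreference_t pref = nullptr;
  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&pref);
  if (status != CUBLAS_STATUS_SUCCESS) return status;
  // Cap the workspace the heuristic may assume; larger caps can unlock
  // split-k style algorithms.
  status = cublasLtMatmulPreferenceSetAttribute(
      pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize, sizeof(workspaceSize));
  int returned = 0;
  if (status == CUBLAS_STATUS_SUCCESS) {
    // C and D share a layout here, matching the non-scaled GEMM call sites.
    status = cublasLtMatmulAlgoGetHeuristic(
        ltHandle, computeDesc, Adesc, Bdesc, Cdesc, Cdesc,
        pref, /*requestedAlgoCount=*/1, result, &returned);
  }
  cublasLtMatmulPreferenceDestroy(pref);
  if (status == CUBLAS_STATUS_SUCCESS && returned == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
  }
  return status;
}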

aten/src/ATen/cuda/CUDAContextLight.h

@@ -2,7 +2,6 @@
// Light-weight version of CUDAContext.h with fewer transitive includes
#include <cstdint>
#include <map>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@@ -88,8 +87,6 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
#if defined(CUDART_VERSION) || defined(USE_ROCM)
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();

aten/src/ATen/cuda/CublasHandlePool.cpp

@@ -83,6 +83,11 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v
#endif
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void createCublasHandle(cublasHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasCreate(handle));
}
@@ -104,11 +109,6 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace().clear();
}
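The two hunks above only relocate the definition of cublas_handle_stream_to_workspace(); the definition itself uses the deliberately leaked function-local singleton idiom. A standalone sketch of that idiom, with generic types and an illustrative name:

#include <map>
#include <string>

static std::map<std::string, int>& registry() {
  // Constructed on first use and intentionally never destroyed: `instance`
  // binds to a heap allocation, so no destructor runs at program exit.
  static auto& instance = *new std::map<std::string, int>;
  return instance;
}

int main() {
  registry()["example"] = 1;  // first call constructs the leaked map
  return 0;
}

Because the instance is never deleted, its destructor (and, in the real map, the destructors of the at::DataPtr workspaces it owns) never runs during static teardown, which is the usual motivation for this pattern.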

benchmarks/dynamo/common.py

@@ -3563,15 +3563,6 @@ def run(runner, args, original_dir=None):
# some of the models do not support use_deterministic_algorithms
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
if args.only is not None and args.only in {
"DebertaForQuestionAnswering",
"RobertaForQuestionAnswering",
"nvidia_deeprecommender",
"volo_d1_224",
}:
# These seem unhappy with numerics of larger cuBLASLt workspace
# sizes following #145130 (due to enabling split-k?)
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False