From b8ea5b01c90c5d9a4d61134feaa9fa87be9703ff Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 21 Aug 2024 15:30:06 +0000 Subject: [PATCH] [fp8 rowwise] Allocate workspace as a PyTorch Tensor (#134110) This makes us pass through the CUDA caching allocator which is safer e.g. in case of CUDA graphs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134110 Approved by: https://github.com/drisspg --- aten/src/ATen/native/cuda/RowwiseScaledMM.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index be543c654a1..dbfbdfe9835 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -305,7 +305,9 @@ void f8f8bf16_rowwise_impl( size_t workspace_size = Gemm::get_workspace_size(arguments); // Allocate workspace memory - cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); + auto workspace = XQ.new_empty( + {static_cast<int64_t>(workspace_size)}, + at::TensorOptions().dtype(at::kByte)); // Check the problem size is supported or not cutlass::Status status = gemm.can_implement(arguments); @@ -314,7 +316,7 @@ } // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.get()); + status = gemm.initialize(arguments, workspace.data_ptr()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); }