[fp8 rowwise] Allocate workspace as a PyTorch Tensor (#134110)

This makes us pass through the CUDA caching allocator which is safer e.g. in case of CUDA graphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134110
Approved by: https://github.com/drisspg
This commit is contained in:
Luca Wehrstedt 2024-08-21 15:30:06 +00:00 committed by PyTorch MergeBot
parent 4c8193b8f0
commit b8ea5b01c9

View file

@@ -305,7 +305,9 @@ void f8f8bf16_rowwise_impl(
size_t workspace_size = Gemm::get_workspace_size(arguments);
// Allocate workspace memory
-  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+  auto workspace = XQ.new_empty(
+      {static_cast<int64_t>(workspace_size)},
+      at::TensorOptions().dtype(at::kByte));
// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
@@ -314,7 +316,7 @@ void f8f8bf16_rowwise_impl(
}
// Initialize CUTLASS kernel with arguments and workspace pointer
-  status = gemm.initialize(arguments, workspace.get());
+  status = gemm.initialize(arguments, workspace.data_ptr());
if (status != cutlass::Status::kSuccess) {
throw std::runtime_error("cutlass cannot initialize");
}