mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[fp8 rowwise] Allocate workspace as a PyTorch Tensor (#134110)
This makes us pass through the CUDA caching allocator which is safer e.g. in case of CUDA graphs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134110 Approved by: https://github.com/drisspg
This commit is contained in:
parent
4c8193b8f0
commit
b8ea5b01c9
1 changed file with 4 additions and 2 deletions
|
|
@ -305,7 +305,9 @@ void f8f8bf16_rowwise_impl(
|
|||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||
|
||||
// Allocate workspace memory
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto workspace = XQ.new_empty(
|
||||
{static_cast<int64_t>(workspace_size)},
|
||||
at::TensorOptions().dtype(at::kByte));
|
||||
|
||||
// Check the problem size is supported or not
|
||||
cutlass::Status status = gemm.can_implement(arguments);
|
||||
|
|
@ -314,7 +316,7 @@ void f8f8bf16_rowwise_impl(
|
|||
}
|
||||
|
||||
// Initialize CUTLASS kernel with arguments and workspace pointer
|
||||
status = gemm.initialize(arguments, workspace.get());
|
||||
status = gemm.initialize(arguments, workspace.data_ptr());
|
||||
if (status != cutlass::Status::kSuccess) {
|
||||
throw std::runtime_error("cutlass cannot initialize");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue