pytorch/c10/cuda/driver_api.h
Yifu Wang ee42a99745 [SymmetricMemory] introduce a binding for cuMemset32Async (#138755)
## This Stack

This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemsetD32Async`

In combination, these aim to give users more flexibility to express custom signaling/synchronization patterns.

## This PR
Make `cuMemsetD32Async` available via `_SymmetricMemory.memset32`. We chose `cuMemsetD32Async` over `cudaMemsetAsync` because it allows for `uint32_t`-wise memset, which gives users more flexibility.

To enable this, we also added the following cuda driver APIs in `c10::cuda::DriverAPI`:
- `cuDevicePrimaryCtxRetain` - for obtaining the primary context of a device in the form of `CUcontext`.
- `cuCtxGetCurrent`/`cuCtxSetCurrent` - for setting and restoring the context for cuda driver APIs such as `cuMemset32Async`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138755
Approved by: https://github.com/weifengpy, https://github.com/eqy, https://github.com/lw
2024-11-05 18:47:24 +00:00

65 lines
2.4 KiB
C++

#pragma once
#include <cuda.h>
#define NVML_NO_UNVERSIONED_FUNC_DEFS
#include <nvml.h>
// Evaluates EXPR (an expression yielding a CUresult) exactly once and, when
// the result is not CUDA_SUCCESS, reports the failure through
// TORCH_CHECK(false, ...).
//
// The message is "CUDA driver error: <text>", where <text> is looked up with
// cuGetErrorString via c10::cuda::DriverAPI::get(). If that lookup fails, or
// yields a null string, the message is "CUDA driver error: unknown error".
//
// Notes:
//  - TORCH_CHECK must be in scope at the point of use.
//  - The locals use a c10_drv_ prefix instead of a double underscore:
//    identifiers containing "__" are reserved for the implementation.
//  - EXPR is parenthesized so the whole argument is always the initializer.
#define C10_CUDA_DRIVER_CHECK(EXPR) \
  do { \
    const CUresult c10_drv_status_ = (EXPR); \
    if (c10_drv_status_ != CUDA_SUCCESS) { \
      /* Start from nullptr so the pointer is never read uninitialized. */ \
      const char* c10_drv_msg_ = nullptr; \
      const CUresult c10_drv_lookup_ = \
          c10::cuda::DriverAPI::get()->cuGetErrorString_( \
              c10_drv_status_, &c10_drv_msg_); \
      if (c10_drv_lookup_ != CUDA_SUCCESS || c10_drv_msg_ == nullptr) { \
        TORCH_CHECK(false, "CUDA driver error: unknown error"); \
      } else { \
        TORCH_CHECK(false, "CUDA driver error: ", c10_drv_msg_); \
      } \
    } \
  } while (0)
// X-macro over the CUDA driver entry points that are always listed.
// C10_LIBCUDA_DRIVER_API(F) expands to F(<function name>) once per entry, in
// the order written here. Only block comments may be placed inside the list:
// a line comment would swallow the trailing line continuation.
#define C10_LIBCUDA_DRIVER_API(OP) \
  OP(cuDeviceGetAttribute) \
  OP(cuMemAddressReserve) \
  OP(cuMemRelease) \
  OP(cuMemMap) \
  OP(cuMemAddressFree) \
  OP(cuMemSetAccess) \
  OP(cuMemUnmap) \
  OP(cuMemCreate) \
  OP(cuMemGetAllocationGranularity) \
  OP(cuMemExportToShareableHandle) \
  OP(cuMemImportFromShareableHandle) \
  OP(cuMemsetD32Async) \
  OP(cuStreamWriteValue32) \
  OP(cuGetErrorString)
// X-macro over the multicast driver entry points. They are listed only when
// CUDA_VERSION is defined and is at least 12030; otherwise the macro expands
// to nothing so users of the list need no version checks of their own.
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030)
#define C10_LIBCUDA_DRIVER_API_12030(OP) \
  OP(cuMulticastAddDevice) \
  OP(cuMulticastBindMem) \
  OP(cuMulticastCreate)
#else
// CUDA_VERSION undefined or older than 12030: empty list.
#define C10_LIBCUDA_DRIVER_API_12030(OP)
#endif
// X-macro over the NVML entry points. C10_NVML_DRIVER_API(F) expands to
// F(<function name>) once per entry, in the order written here.
#define C10_NVML_DRIVER_API(OP) \
  OP(nvmlInit_v2) \
  OP(nvmlDeviceGetHandleByPciBusId_v2) \
  OP(nvmlDeviceGetNvLinkRemoteDeviceType) \
  OP(nvmlDeviceGetNvLinkRemotePciInfo_v2) \
  OP(nvmlDeviceGetComputeRunningProcesses)
namespace c10::cuda {
struct DriverAPI {
#define CREATE_MEMBER(name) decltype(&name) name##_;
C10_LIBCUDA_DRIVER_API(CREATE_MEMBER)
C10_LIBCUDA_DRIVER_API_12030(CREATE_MEMBER)
C10_NVML_DRIVER_API(CREATE_MEMBER)
#undef CREATE_MEMBER
static DriverAPI* get();
static void* get_nvml_handle();
};
} // namespace c10::cuda