Fixes to DSA infra (#91835)

Differential Revision: D42397325

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91835
Approved by: https://github.com/soumith
This commit is contained in:
Richard Barnes 2023-01-12 21:54:22 +00:00 committed by PyTorch MergeBot
parent 4636fe701c
commit 6f749fd171
12 changed files with 60 additions and 85 deletions

View file

@@ -18,7 +18,7 @@ static __device__ void dstrcpy(char* dst, const char* src) {
*dst = '\0';
}
__device__ __noinline__ void dsa_add_new_assertion_failure(
__device__ void dsa_add_new_assertion_failure(
DeviceAssertionsData* assertions_data,
const char* assertion_msg,
const char* filename,

View file

@@ -21,6 +21,33 @@ namespace cuda {
namespace {
#ifdef TORCH_USE_CUDA_DSA
/// Get current device id
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_id() {
int device = -1;
C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return device;
}
/// Get a device's compute capability - note that this dangerously assumes
/// that if one CUDA GPU supports device-side assertions they all do. This is
/// probably fine since the latest CUDA GPU that doesn't support UVM is the
/// K80 released 2014-11-17. Mixing that GPU with a newer one is likely to be
/// rare enough that the defensive programming here is not worth extending.
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_compute_capability(const int device_num) {
int compute_capability = -1;
C10_CUDA_ERROR_HANDLED(cudaDeviceGetAttribute(
&compute_capability, cudaDevAttrComputeCapabilityMajor, device_num));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return compute_capability;
}
#endif
/// Get the number of CUDA devices
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
@@ -60,40 +87,13 @@ void uvm_deleter(DeviceAssertionsData* uvm_assertions_ptr) {
}
}
#ifdef TORCH_USE_CUDA_DSA
/// Get current device id
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_id() {
int device = -1;
C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return device;
}
/// Get a device's compute capability - note that this dangerously assumes
/// that if one CUDA GPU supports device-side assertions they all do. This is
/// probably fine since the latest CUDA GPU that doesn't support UVM is the
/// K80 released 2014-11-17. Mixing that GPU with a newer one is likely to be
/// rare enough that the defensive programming here is not worth extending.
/// We need our own implementation of this function to prevent
/// an infinite initialization loop for CUDAKernelLaunchRegistry
int dsa_get_device_compute_capability(const int device_num) {
int compute_capability = -1;
C10_CUDA_ERROR_HANDLED(cudaDeviceGetAttribute(
&compute_capability, cudaDevAttrComputeCapabilityMajor, device_num));
CHECK_CUDA_API_CALL_WITHOUT_CHECKING_DEVICE_ASSERTS();
return compute_capability;
}
#endif
} // namespace
/// Check that kernels ran correctly by checking the message buffer. BLOCKING.
std::string c10_retrieve_device_side_assertion_info() {
#ifdef TORCH_USE_CUDA_DSA
const auto& launch_registry = CUDAKernelLaunchRegistry::get_singleton_ref();
if (!launch_registry.enabled) {
if (!launch_registry.enabled_at_runtime) {
return "Device-side assertion tracking was not enabled by user.";
} else if (!launch_registry.do_all_devices_support_managed_memory) {
return "Device-side assertions disabled because not all devices support managed memory.";
@@ -115,20 +115,7 @@ std::string c10_retrieve_device_side_assertion_info() {
std::stringstream oss;
{
oss << "This process interacted the following GPUs = {";
bool first_gpu_listed = true;
for (const auto& x : uvm_assertions) {
if (x) {
if (!first_gpu_listed) {
oss << ","
}
first_gpu_listed = true;
oss << x;
}
}
oss << "}" << std::endl;
}
oss << "Looking for device-side assertion failure information...\n";
// Loop over each device that could be managed by the process
for (const auto device_num : c10::irange(assertion_data.size())) {
@@ -202,7 +189,7 @@ CUDAKernelLaunchRegistry::CUDAKernelLaunchRegistry()
: do_all_devices_support_managed_memory(
dsa_check_if_all_devices_support_managed_memory()),
gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()),
enabled(check_env_for_dsa_enabled()) {
enabled_at_runtime(check_env_for_dsa_enabled()) {
for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) {
uvm_assertions.emplace_back(nullptr, uvm_deleter);
}
@@ -226,7 +213,7 @@ uint32_t CUDAKernelLaunchRegistry::insert(
const char* kernel_name,
const int32_t stream_id) {
#ifdef TORCH_USE_CUDA_DSA
if (!is_enabled()) {
if (!enabled_at_runtime) {
return 0;
}
@@ -274,7 +261,7 @@ CUDAKernelLaunchRegistry::snapshot() const {
DeviceAssertionsData* CUDAKernelLaunchRegistry::
get_uvm_assertions_ptr_for_current_device() {
#ifdef TORCH_USE_CUDA_DSA
if (!is_enabled()) {
if (!enabled_at_runtime) {
return nullptr;
}
@@ -352,16 +339,5 @@ bool CUDAKernelLaunchRegistry::has_failed() const {
return false;
}
bool CUDAKernelLaunchRegistry::is_enabled() const {
#ifdef TORCH_USE_CUDA_DSA
std::cerr << ""
#else
std::cerr
<< "TORCH_USE_CUDA_DSA not enabled in CUDAKernelLaunchRegistry::is_enabled"
<< std::endl;
return false;
#endif
}
} // namespace cuda
} // namespace c10

View file

@@ -130,13 +130,15 @@ class C10_CUDA_API CUDAKernelLaunchRegistry {
/// Whether or not to gather stack traces when launching kernels
bool gather_launch_stacktrace = false;
/// Whether or not host-side DSA is enabled or disabled at run-time
/// Device-side code cannot be adjusted at run-time
bool enabled = false;
/// Note: Device-side code cannot be enabled/disabled at run-time
bool enabled_at_runtime = false;
/// Whether or not a device has indicated a failure
bool has_failed() const;
/// Since multiple mechanisms can enable/disable, we add a function that
/// aggregates them
bool is_enabled() const;
#ifdef TORCH_USE_CUDA_DSA
const bool enabled_at_compile_time = true;
#else
const bool enabled_at_compile_time = false;
#endif
};
std::string c10_retrieve_device_side_assertion_info();
@@ -147,9 +149,9 @@ std::string c10_retrieve_device_side_assertion_info();
// Each kernel launched with TORCH_DSA_KERNEL_LAUNCH
// requires the same input arguments. We introduce the following macro to
// standardize these.
#define TORCH_DSA_KERNEL_ARGS \
c10::cuda::DeviceAssertionsData *const assertions_data, \
uint32_t assertion_caller_id
#define TORCH_DSA_KERNEL_ARGS \
[[maybe_unused]] c10::cuda::DeviceAssertionsData *const assertions_data, \
[[maybe_unused]] uint32_t assertion_caller_id
// This macro can be used to pass the DSA arguments onward to another
// function

View file

@@ -12,7 +12,7 @@ const char* get_cuda_check_suffix() noexcept {
return "";
} else {
return "\nCUDA kernel errors might be asynchronously reported at some"
" other API call,so the stacktrace below might be incorrect."
" other API call, so the stacktrace below might be incorrect."
"\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
}
}

View file

@@ -14,19 +14,13 @@
using ::testing::HasSubstr;
void did_not_fail_diagnostics() {
#ifdef TORCH_USE_CUDA_DSA
std::cerr << "DSA was enabled" << std::endl;
#else
std::cerr << "DSA was not enabled" << std::endl;
#endif
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_runtime = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_runtime
<< std::endl;
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled() = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().is_enabled()
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_compile_time = "
<< c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_compile_time
<< std::endl;
std::cerr
<< "c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().do_all_devices_support_managed_memory = "
@@ -92,11 +86,10 @@ void cuda_device_assertions_1_var_test() {
TEST(CUDATest, cuda_device_assertions_1_var_test) {
#ifdef TORCH_USE_CUDA_DSA
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
std::cerr << "BEFORE TEST" << std::endl;
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled_at_runtime = true;
did_not_fail_diagnostics();
cuda_device_assertions_1_var_test();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -96,6 +96,6 @@ TEST(CUDATest, cuda_device_assertions_catches_stream) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_catches_stream();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -81,6 +81,6 @@ TEST(CUDATest, cuda_device_assertions_catches_thread_and_block_and_device) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_catches_thread_and_block_and_device();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -99,7 +99,7 @@ TEST(CUDATest, cuda_device_assertions_from_2_processes) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_from_2_processes();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -88,6 +88,6 @@ TEST(CUDATest, cuda_device_assertions_multiple_writes_from_blocks_and_threads) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_blocks_and_threads();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -85,6 +85,6 @@ TEST(CUDATest, cuda_device_assertions_multiple_writes_from_multiple_blocks) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_multiple_blocks();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -73,6 +73,6 @@ TEST(CUDATest, cuda_device_assertions_multiple_writes_from_same_block) {
c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().enabled = true;
cuda_device_assertions_multiple_writes_from_same_block();
#else
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled.";
GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#endif
}

View file

@@ -201,6 +201,10 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
return gpu_id_;
}
inline c10::cuda::CUDAStream stream() const {
return at::cuda::getStreamFromExternal(getCudaObjects().GetStream(gpu_id_), gpu_id_);
}
inline cudaStream_t cuda_stream() const {
return getCudaObjects().GetStream(gpu_id_);
}