diff --git a/BUILD.bazel b/BUILD.bazel index 523f2ad2ce6..f079c76a7f7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -574,7 +574,7 @@ cu_library( name = "torch_cuda", srcs = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], copts = torch_cuda_half_options, @@ -722,7 +722,7 @@ cc_library( "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ], )) + torch_sources, diff --git a/build_variables.bzl b/build_variables.bzl index 7fc5802550a..e7434305c72 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -691,7 +691,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/intra_node_comm.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemory.cu", "torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/NCCLUtils.cu", + "torch/csrc/distributed/c10d/Utils.cu", "torch/csrc/distributed/rpc/tensorpipe_cuda.cpp", "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", ] diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index e45d9d09b25..b34ff33333d 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -10,7 +10,6 @@ #include #include -#include #include #include #include @@ -715,11 +714,6 @@ struct NCCLTraceBuffer { bool includeStackTraces, bool onlyActive); }; - -// Check for NaNs in a tensor on a given stream. If any are found, throw a -// device-side error. -void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream); - } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index abcf493e62f..71a081b5e49 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -2638,6 +2638,9 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( OpType opType, const char* profilingTitle, bool avoidRecordStreams) { + if (enableNanCheck_) { + checkForNan(input); + } // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; c10::cuda::CaptureStatus capture_status = @@ -2693,10 +2696,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { - checkForNan(input, ncclStream); - } - // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -2998,6 +2997,9 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( PreProcess pre, PostProcess post, const char* profilingTitle) { + if (enableNanCheck_) { + checkForNan(tensor); + } // avoidRecordStreams_ note: // send, recv, and irecv should be ok with avoidRecordStreams, // However, for isend, I don't think the API requires the user @@ -3126,10 +3128,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // is gpuGuard needed for the if block below, or can i swap them at::cuda::OptionalCUDAGuard gpuGuard; - if (enableNanCheck_) { - checkForNan(tensor, ncclStream); - } - if (!coalescing_state_) { // Start event should only be recorded before the ncclGroupStart() if (work->timingEnabled_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cu b/torch/csrc/distributed/c10d/Utils.cu similarity index 86% rename from torch/csrc/distributed/c10d/NCCLUtils.cu rename to torch/csrc/distributed/c10d/Utils.cu index 5dcf01fb98b..ae2017efdf8 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cu +++ b/torch/csrc/distributed/c10d/Utils.cu @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include @@ -20,7 +20,7 @@ __global__ void checkForNaN(T* data, size_t size) { } // CHECK if a Tensor contains NAN in any of its element -void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { +void checkForNan(const at::Tensor& tensor) { // skip check for non float types if (!torch::is_floating_point(tensor)) { return; @@ -40,7 +40,7 @@ void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { tensor.scalar_type(), "checkForNaN", [&] { - checkForNaN<<>>( + checkForNaN<<>>( tensor.data_ptr(), tensor.numel()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc3..5c736539340 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -611,6 +611,8 @@ using SizeType = uint64_t; // Since SOCKET_ERROR = -1 in MSVC, so also leverage SYSCHECK_ERR_RETURN_NEG1 #define SYSCHECK_ERR_RETURN_NEG1(expr) SYSCHECK(expr, __output != -1) +void checkForNan(const at::Tensor& tensor); + namespace tcputil { // Send and receive