Don't call cudaStreamDestroy at destruction time (#15692)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15692

It was leading to occasional crashes with dynamically linked CUDA, because the CUDA runtime was already destroyed by the time the stream destructor ran.

Also, unique_ptr<T[]> is more suitable than deque<T> for this purpose.
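To illustrate the fix, here is a minimal sketch of the "leak on purpose" pattern the diff below adopts (simplified names, not the exact PyTorch code):

#include <cuda_runtime.h>

// Simplified stand-in for the stream wrapper changed below (illustrative only).
struct LeakyStream {
  cudaStream_t stream = nullptr;

  void create() {
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
  }

  ~LeakyStream() {
    // Deliberately empty. Globals of this type are destroyed during
    // static destruction, and with a dynamically linked libcudart the
    // runtime may already be torn down by then, so cudaStreamDestroy
    // could crash. Leaking a stream at process exit is harmless.
    // cudaStreamDestroy(stream);  // the call this commit removes
  }
};

static LeakyStream g_stream;  // destructor runs at exit, possibly after CUDA teardown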

Reviewed By: Yangqing

Differential Revision: D13571988

fbshipit-source-id: 37eb26dfbe361c49160367b53f87bd037c6c0e46
Dmytro Dzhulgakov authored on 2019-01-11 12:32:50 -08:00; committed by Facebook Github Bot
parent 726341fea7
commit 96ea2594d8
9 changed files with 143 additions and 99 deletions


@@ -11,6 +11,7 @@
#include <c10/macros/Macros.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDAException.h>
namespace c10 {
namespace cuda {


@@ -31,3 +31,8 @@
#else
#define C10_CUDA_API C10_CUDA_IMPORT
#endif
/**
* The maximum number of GPUs that we recognize.
*/
#define C10_COMPILE_TIME_MAX_GPUS 16
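A hypothetical usage sketch of this macro: size static per-device storage at compile time, then assert at runtime that the actual device count fits (the CUDAStream.cpp hunk further down does exactly this with AT_ASSERTM):

#include <cassert>

static int per_device_state[C10_COMPILE_TIME_MAX_GPUS];  // macro from above

void check_gpu_count(int num_gpus) {
  assert(num_gpus <= C10_COMPILE_TIME_MAX_GPUS &&
         "more GPUs than the compiled-in max; increase it and recompile");
}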


@@ -1,26 +1,35 @@
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <mutex>
#include <array>
#include <atomic>
#include <cstdint>
#include <deque>
#include <mutex>
#include <vector>
#include <array>
#include <iostream>
namespace c10 {
namespace cuda {
namespace {
// Internal implementation is entirely hidden
struct CUDAStreamInternals {
CUDAStreamInternals() = default;
// Internal implementation that leaks the stream. It's not intended to be used
// outside of this file.
struct LeakyStreamInternals {
LeakyStreamInternals() = default;
C10_DISABLE_COPY_AND_ASSIGN(LeakyStreamInternals);
~CUDAStreamInternals() {
if (stream) cudaStreamDestroy(stream);
~LeakyStreamInternals() {
// NB: this code is invoked only in the destruction of global variables
// (since we never shrink the corresponding vectors). At this point the CUDA
// runtime might be already destroyed and invoking cudaStreamDestroy leads
// to a crash. It's likely an issue in CUDA, but to be safe - let's just
// "forget" the destruction.
// if (stream) cudaStreamDestroy(stream);
}
DeviceIndex device_index = -1;
@@ -37,13 +46,13 @@ static constexpr unsigned int kDefaultFlags = cudaStreamNonBlocking;
// Note: stream priority is not supported by HIP
// Note: lower numbers are higher priorities, zero is default priority
#ifndef __HIP_PLATFORM_HCC__
static int kHighPriority = -1;
static int kLowPriority = 0;
static int kHighPriority = -1;
static int kLowPriority = 0;
#endif // __HIP_PLATFORM_HCC__
// Default streams
static std::once_flag init_flag;
static std::vector<CUDAStreamInternals> default_streams;
static LeakyStreamInternals default_streams[C10_COMPILE_TIME_MAX_GPUS];
// Non-default streams
// Note: the number of CUDA devices is determined at run time,
@@ -53,11 +62,16 @@ static std::vector<CUDAStreamInternals> default_streams;
// the low and high priority counters track, for each device, the next stream
// in the pool to be returned when a stream is requested (round-robin fashion
// , see the note in CUDAStream.h).
static std::deque<std::once_flag> device_flags;
static std::deque<std::atomic<uint32_t>> low_priority_counters;
static std::deque<std::atomic<uint32_t>> high_priority_counters;
static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> low_priority_streams;
static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priority_streams;
//
// unique_ptr<T[]> is used instead of vector<T> because T might be non-moveable
// and non-copyable.
static std::once_flag device_flags[C10_COMPILE_TIME_MAX_GPUS];
static std::atomic<uint32_t> low_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
static std::atomic<uint32_t> high_priority_counters[C10_COMPILE_TIME_MAX_GPUS];
static std::array<LeakyStreamInternals, kStreamsPerPool>
low_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
static std::array<LeakyStreamInternals, kStreamsPerPool>
high_priority_streams[C10_COMPILE_TIME_MAX_GPUS];
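The switch away from deque/vector matters because the element types here are non-movable, so vector's growth machinery cannot handle them; a self-contained sketch of the constraint, with std::atomic standing in for the stream internals:

#include <atomic>
#include <cstdint>
#include <memory>
#include <vector>

std::vector<std::atomic<uint32_t>> v;
// v.resize(8);  // does not compile: std::atomic is neither copyable nor
                 // movable, and vector::resize may relocate elements.

// These work, because elements are constructed in place and never moved:
std::atomic<uint32_t> fixed[16];                              // compile-time size
auto dynamic = std::make_unique<std::atomic<uint32_t>[]>(8);  // run-time size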
// Note [StreamId assignment]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -88,8 +102,8 @@ static std::vector<std::array<CUDAStreamInternals, kStreamsPerPool>> high_priori
enum class StreamIdType : uint8_t {
DEFAULT = 0x0,
LOW = 0x1,
HIGH = 0x2,
LOW = 0x1,
HIGH = 0x2,
};
std::ostream& operator<<(std::ostream& stream, StreamIdType s) {
@@ -123,15 +137,17 @@ static inline size_t streamIdIndex(StreamId s) {
}
StreamId makeStreamId(StreamIdType st, size_t si) {
return (static_cast<StreamId>(st) << kStreamsPerPoolBits) | static_cast<StreamId>(si);
return (static_cast<StreamId>(st) << kStreamsPerPoolBits) |
static_cast<StreamId>(si);
}
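A worked sketch of the StreamId bit layout (assuming kStreamsPerPoolBits = 5, kStreamsPerPool = 32, and a 64-bit StreamId, which match what this file uses): the pool index sits in the low bits and the stream type just above it.

#include <cstddef>
#include <cstdint>

using StreamId = int64_t;                                  // as in c10::Stream
constexpr int kStreamsPerPoolBits = 5;                     // assumed value
constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits;  // 32

constexpr StreamId pack(uint8_t type, size_t index) {
  return (static_cast<StreamId>(type) << kStreamsPerPoolBits) |
      static_cast<StreamId>(index);
}
constexpr uint8_t unpack_type(StreamId s) {
  return static_cast<uint8_t>(s >> kStreamsPerPoolBits);
}
constexpr size_t unpack_index(StreamId s) {
  return static_cast<size_t>(s & (kStreamsPerPool - 1));
}

// HIGH (0x2) stream at pool index 3: (0x2 << 5) | 3 == 0x43.
static_assert(pack(0x2, 3) == 0x43, "");
static_assert(unpack_type(0x43) == 0x2 && unpack_index(0x43) == 3, "");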
template <typename T, typename A>
static bool pointer_within(const T* ptr, const A& arr) {
return std::greater_equal<const T*>()(ptr, arr.data()) && std::less<const T*>()(ptr, arr.data() + arr.size());
return std::greater_equal<const T*>()(ptr, arr.data()) &&
std::less<const T*>()(ptr, arr.data() + arr.size());
}
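The std::greater_equal/std::less dance exists because applying the built-in <, <=, etc. to pointers that do not point into the same array is undefined behavior, while the standard function objects are guaranteed to impose a total order over all pointers. A standalone sketch of the same containment test:

#include <array>
#include <cstddef>
#include <functional>

template <typename T, std::size_t N>
bool within(const T* p, const std::array<T, N>& arr) {
  // Plain "p >= arr.data()" would be UB when p points elsewhere;
  // std::less and friends define a total order over all pointers.
  return std::greater_equal<const T*>()(p, arr.data()) &&
         std::less<const T*>()(p, arr.data() + N);
}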
static StreamId CUDAStream_getStreamId(const CUDAStreamInternals* ptr) {
static StreamId CUDAStream_getStreamId(const LeakyStreamInternals* ptr) {
// Hypothetically, we could store the stream ID in the stream. But that
// introduces a degree of freedom which could lead to bugs (where we
// misnumber streams in the pool, or overwrite the number). Better
@@ -149,21 +165,30 @@ static StreamId CUDAStream_getStreamId(const CUDAStreamInternals* ptr) {
// NB: Because ptr may not necessarily lie within the array, we must use
// std::less and similar templates to avoid UB that arises when
// doing an operator< comparison.
if (pointer_within<CUDAStreamInternals>(ptr, low_priority_streams[device_index])) {
return makeStreamId(StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
if (pointer_within<LeakyStreamInternals>(
ptr, low_priority_streams[device_index])) {
return makeStreamId(
StreamIdType::LOW, ptr - low_priority_streams[device_index].data());
}
// Check if it's a high priority stream
if (pointer_within<CUDAStreamInternals>(ptr, high_priority_streams[device_index])) {
return makeStreamId(StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
if (pointer_within<LeakyStreamInternals>(
ptr, high_priority_streams[device_index])) {
return makeStreamId(
StreamIdType::HIGH, ptr - high_priority_streams[device_index].data());
}
AT_ASSERTM(0, "Could not compute stream ID for ", ptr, " on device ", device_index,
" (something has gone horribly wrong!)");
AT_ASSERTM(
0,
"Could not compute stream ID for ",
ptr,
" on device ",
device_index,
" (something has gone horribly wrong!)");
}
// Thread-local current streams
static thread_local CUDAStreamInternals** current_streams = nullptr;
static thread_local LeakyStreamInternals** current_streams = nullptr;
// Populates global values and creates a default stream for each device.
// Note: the default stream on each device is signified by a nullptr,
@@ -173,14 +198,14 @@ static thread_local CUDAStreamInternals** current_streams = nullptr;
// Warning: this function must only be called once!
static void initGlobalStreamState() {
num_gpus = device_count();
// Resizes deques and vectors
default_streams.resize(num_gpus);
device_flags.resize(num_gpus);
low_priority_counters.resize(num_gpus);
high_priority_counters.resize(num_gpus);
low_priority_streams.resize(num_gpus);
high_priority_streams.resize(num_gpus);
// Check if the number of GPUs matches the expected compile-time max number
// of GPUs.
AT_ASSERTM(
num_gpus <= C10_COMPILE_TIME_MAX_GPUS,
"Number of CUDA devices on the machine is larger than the compiled "
"max number of gpus expected (",
C10_COMPILE_TIME_MAX_GPUS,
"). Increase that and recompile.");
// Initializes default streams
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
@@ -204,23 +229,17 @@ static void initDeviceStreamState(DeviceIndex device_index) {
lowpri_stream.device_index = device_index;
hipri_stream.device_index = device_index;
#ifndef __HIP_PLATFORM_HCC__
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
&lowpri_stream.stream
, kDefaultFlags
, kLowPriority));
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
&hipri_stream.stream
, kDefaultFlags
, kHighPriority));
#else
C10_CUDA_CHECK(cudaStreamCreateWithFlags(
&lowpri_stream.stream
, kDefaultFlags));
C10_CUDA_CHECK(cudaStreamCreateWithFlags(
&hipri_stream.stream
, kDefaultFlags));
#endif // __HIP_PLATFORM_HCC__
#ifndef __HIP_PLATFORM_HCC__
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
&lowpri_stream.stream, kDefaultFlags, kLowPriority));
C10_CUDA_CHECK(cudaStreamCreateWithPriority(
&hipri_stream.stream, kDefaultFlags, kHighPriority));
#else
C10_CUDA_CHECK(
cudaStreamCreateWithFlags(&lowpri_stream.stream, kDefaultFlags));
C10_CUDA_CHECK(
cudaStreamCreateWithFlags(&hipri_stream.stream, kDefaultFlags));
#endif // __HIP_PLATFORM_HCC__
}
}
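For reference, a sketch of how these priority values feed the CUDA API (lower numbers mean higher priority, as the note above says; error handling omitted; the legal range can be queried rather than hard-coded):

#include <cuda_runtime.h>

void create_prioritized_streams() {
  int least = 0, greatest = 0;  // numerically, least == 0 and greatest <= -1
  cudaDeviceGetStreamPriorityRange(&least, &greatest);

  cudaStream_t low, high;
  // kLowPriority = 0 and kHighPriority = -1 from the hunk above fall in
  // this range on devices that support stream priorities.
  cudaStreamCreateWithPriority(&low, cudaStreamNonBlocking, least);
  cudaStreamCreateWithPriority(&high, cudaStreamNonBlocking, greatest);
}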
@@ -229,10 +248,13 @@ static void initCUDAStreamsOnce() {
// Inits default streams (once, globally)
std::call_once(init_flag, initGlobalStreamState);
if (current_streams) return;
if (current_streams) {
return;
}
// Inits current streams (thread local) to default streams
current_streams = (CUDAStreamInternals**) malloc(num_gpus * sizeof(CUDAStreamInternals*));
current_streams =
(LeakyStreamInternals**)malloc(num_gpus * sizeof(LeakyStreamInternals*));
for (auto i = decltype(num_gpus){0}; i < num_gpus; ++i) {
current_streams[i] = &default_streams[i];
}
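The thread-local table above can be sketched in isolation like this (hypothetical stand-in type; the per-thread array is malloc'd once and, as far as this file shows, never freed, which sidesteps any ordering questions at thread exit):

#include <cstdlib>

struct Internals;  // stand-in for LeakyStreamInternals

static thread_local Internals** tls_current = nullptr;

void ensure_tls(int num_gpus, Internals* defaults) {
  if (tls_current) {
    return;
  }
  // One slot per device, initially pointing at that device's default stream.
  tls_current =
      static_cast<Internals**>(std::malloc(num_gpus * sizeof(Internals*)));
  for (int i = 0; i < num_gpus; ++i) {
    tls_current[i] = &defaults[i];
  }
}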
@@ -245,37 +267,50 @@ static inline void check_gpu(DeviceIndex device_index) {
// Helper to determine the index of the stream to return
// Note: Streams are returned round-robin (see note in CUDAStream.h)
static uint32_t get_idx(std::atomic<uint32_t> &counter) {
static uint32_t get_idx(std::atomic<uint32_t>& counter) {
auto raw_idx = counter++;
return raw_idx % kStreamsPerPool;
}
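Round-robin selection reduces to an atomic fetch-and-increment followed by a modulo; a self-contained sketch (assuming a pool size of 32, a power of two, so uint32_t counter wraparound stays consistent):

#include <atomic>
#include <cstdint>

constexpr uint32_t kStreamsPerPool = 32;  // assumed pool size

static std::atomic<uint32_t> counter{0};

uint32_t next_idx() {
  uint32_t raw = counter++;      // atomic fetch_add(1), thread-safe
  return raw % kStreamsPerPool;  // 0, 1, ..., 31, 0, 1, ...
}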
// See Note [StreamId assignment]
CUDAStreamInternals* CUDAStream_internals(CUDAStream s) {
LeakyStreamInternals* CUDAStream_internals(CUDAStream s) {
c10::DeviceIndex device_index = s.device_index();
StreamIdType st = streamIdType(s.unwrap().id());
size_t si = streamIdIndex(s.unwrap().id());
switch (st) {
case StreamIdType::DEFAULT:
AT_ASSERTM(si == 0, "Unrecognized stream ", s.unwrap(),
" (I think this should be the default stream, but I got a non-zero index ", si, ").",
" Did you manufacture the StreamId yourself? Don't do that; use the",
" official API like c10::cuda::getStreamFromPool() to get a new stream.");
AT_ASSERTM(
si == 0,
"Unrecognized stream ",
s.unwrap(),
" (I think this should be the default stream, but I got a non-zero index ",
si,
").",
" Did you manufacture the StreamId yourself? Don't do that; use the",
" official API like c10::cuda::getStreamFromPool() to get a new stream.");
return &default_streams[device_index];
case StreamIdType::LOW:
return &low_priority_streams[device_index][si];
case StreamIdType::HIGH:
return &high_priority_streams[device_index][si];
default:
AT_ASSERTM(0, "Unrecognized stream ", s.unwrap(), " (I didn't recognize the stream type, ", st, ")");
AT_ASSERTM(
0,
"Unrecognized stream ",
s.unwrap(),
" (I didn't recognize the stream type, ",
st,
")");
}
}
CUDAStream CUDAStream_fromInternals(const CUDAStreamInternals* ptr) {
return CUDAStream(CUDAStream::UNCHECKED,
Stream(Stream::UNSAFE,
c10::Device(DeviceType::CUDA, ptr->device_index),
CUDAStream_getStreamId(ptr)));
CUDAStream CUDAStream_fromInternals(const LeakyStreamInternals* ptr) {
return CUDAStream(
CUDAStream::UNCHECKED,
Stream(
Stream::UNSAFE,
c10::Device(DeviceType::CUDA, ptr->device_index),
CUDAStream_getStreamId(ptr)));
}
} // anonymous namespace
@@ -290,14 +325,16 @@ cudaStream_t CUDAStream::stream() const {
// Note: when called the first time on a device, this will create the
// stream pools for that device.
CUDAStream getStreamFromPool(
const bool isHighPriority
, DeviceIndex device_index) {
const bool isHighPriority,
DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1) device_index = current_device();
if (device_index == -1)
device_index = current_device();
check_gpu(device_index);
// Initializes the stream pools (once)
std::call_once(device_flags[device_index], initDeviceStreamState, device_index);
std::call_once(
device_flags[device_index], initDeviceStreamState, device_index);
if (isHighPriority) {
const auto idx = get_idx(high_priority_counters[device_index]);
@@ -310,13 +347,17 @@ CUDAStream getStreamFromPool(
CUDAStream getDefaultCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1) device_index = current_device();
if (device_index == -1) {
device_index = current_device();
}
check_gpu(device_index);
return CUDAStream_fromInternals(&default_streams[device_index]);
}
CUDAStream getCurrentCUDAStream(DeviceIndex device_index) {
initCUDAStreamsOnce();
if (device_index == -1) device_index = current_device();
if (device_index == -1) {
device_index = current_device();
}
check_gpu(device_index);
return CUDAStream_fromInternals(current_streams[device_index]);
}
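Putting the three accessors together, caller-side usage looks roughly like this (a sketch based on the signatures above, not code from this commit):

#include <c10/cuda/CUDAStream.h>

void example() {
  // Round-robin stream from the high-priority pool on the current device:
  c10::cuda::CUDAStream pool = c10::cuda::getStreamFromPool(/*isHighPriority=*/true);
  // The device's default stream and this thread's current stream:
  c10::cuda::CUDAStream def = c10::cuda::getDefaultCUDAStream();
  c10::cuda::CUDAStream cur = c10::cuda::getCurrentCUDAStream();
  cudaStream_t raw = pool.stream();  // underlying cudaStream_t
  (void)def; (void)cur; (void)raw;
}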
@@ -333,4 +374,4 @@ std::ostream& operator<<(std::ostream& stream, const CUDAStream& s) {
}
} // namespace cuda
} // namespace at
} // namespace c10


@@ -212,8 +212,8 @@ std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> ncclOpDevInfer(
REGISTER_CUDA_OPERATOR(NCCLAllreduce, NCCLAllreduceOp);
OPERATOR_SCHEMA(NCCLAllreduce)
.NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.CostInferenceFunction(NCCLAllreduceOp::CostInference)
.TensorInferenceFunction(NCCLAllreduceOp::ShapeInference)
.IdenticalTypeAndShape()
@@ -224,8 +224,8 @@ SHOULD_NOT_DO_GRADIENT(NCCLAllreduce);
REGISTER_CUDA_OPERATOR(NCCLBroadcast, NCCLBroadcastOp);
OPERATOR_SCHEMA(NCCLBroadcast)
.NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.IdenticalTypeAndShape()
.InputsCanCrossDevices()
.EnforceOneToOneInplace()
@@ -235,7 +235,7 @@ SHOULD_NOT_DO_GRADIENT(NCCLBroadcast);
REGISTER_CUDA_OPERATOR(NCCLReduce, NCCLReduceOp);
OPERATOR_SCHEMA(NCCLReduce)
.NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1)
.IdenticalTypeAndShapeOfInput(0)
.InputsCanCrossDevices()
@@ -245,16 +245,16 @@ SHOULD_NOT_DO_GRADIENT(NCCLReduce);
REGISTER_CUDA_OPERATOR(NCCLAllGather, NCCLAllGatherOp);
OPERATOR_SCHEMA(NCCLAllGather)
.NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.InputsCanCrossDevices()
.DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLAllGather);
REGISTER_CUDA_OPERATOR(NCCLReduceScatter, NCCLReduceScatterOp);
OPERATOR_SCHEMA(NCCLReduceScatter)
.NumInputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, CAFFE2_COMPILE_TIME_MAX_GPUS)
.NumInputs(1, C10_COMPILE_TIME_MAX_GPUS)
.NumOutputs(1, C10_COMPILE_TIME_MAX_GPUS)
.InputsCanCrossDevices()
.DeviceInferenceFunction(ncclOpDevInfer);
SHOULD_NOT_DO_GRADIENT(NCCLReduceScatter);


@@ -25,6 +25,7 @@
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "c10/cuda/CUDAMacros.h"
#include "c10/cuda/CUDAMathCompat.h"
// Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT. On Windows, this
@@ -94,10 +95,6 @@ constexpr int kFp16CUDADevicePropMajor = 3;
#endif // __GNUC__
#endif // CUDA_VERSION >= 9000
/**
* The maximum number of GPUs that caffe2 recognizes.
*/
#define CAFFE2_COMPILE_TIME_MAX_GPUS 16
/**
* The maximum number of peers that each gpu can have when doing p2p setup.
* Currently, according to NVidia documentation, each device can support a


@@ -178,8 +178,8 @@ static std::unordered_map<void*, uint8_t> g_cuda_device_affiliation;
// Data structures for optional memory tracking. Access to these structures
// is guarded by the CUDAContext::mutex.
static std::unordered_map<void*, long> g_size_map;
static std::vector<long> g_total_by_gpu_map(CAFFE2_COMPILE_TIME_MAX_GPUS, 0);
static std::vector<long> g_max_by_gpu_map(CAFFE2_COMPILE_TIME_MAX_GPUS, 0);
static std::vector<long> g_total_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
static std::vector<long> g_max_by_gpu_map(C10_COMPILE_TIME_MAX_GPUS, 0);
static long g_total_mem = 0;
static long g_last_rep = 0;
@@ -207,11 +207,11 @@ static void Caffe2InitializeCuda() {
// of GPUs.
CAFFE_ENFORCE_LE(
NumCudaDevices(),
CAFFE2_COMPILE_TIME_MAX_GPUS,
C10_COMPILE_TIME_MAX_GPUS,
"Number of CUDA devices on the machine is larger than the compiled "
"max number of gpus expected (",
CAFFE2_COMPILE_TIME_MAX_GPUS,
"). Increase that and recompile the caffe binary.");
C10_COMPILE_TIME_MAX_GPUS,
"). Increase that and recompile.");
for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
DeviceGuard g(i);


@@ -57,7 +57,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
private:
ThreadLocalCUDAObjects() {
for (DeviceIndex i = 0; i < CAFFE2_COMPILE_TIME_MAX_GPUS; ++i) {
for (DeviceIndex i = 0; i < C10_COMPILE_TIME_MAX_GPUS; ++i) {
cuda_streams_[i] = vector<c10::cuda::CUDAStream>();
}
}
@@ -153,7 +153,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
// WARNING: mapping from logical stream ID to c10::cuda::CUDAStream
// is NOT bijective; multiple logical stream IDs may map to the
// same underlying stream ID.
vector<c10::cuda::CUDAStream> cuda_streams_[CAFFE2_COMPILE_TIME_MAX_GPUS];
vector<c10::cuda::CUDAStream> cuda_streams_[C10_COMPILE_TIME_MAX_GPUS];
std::unordered_map<c10::cuda::CUDAStream, cublasHandle_t> cublas_handles_;
#ifdef CAFFE2_USE_CUDNN
std::unordered_map<c10::cuda::CUDAStream, cudnnHandle_t> cudnn_handles_;
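The non-bijective mapping the comment above warns about falls out of the pool design: logical stream IDs are dense per-GPU indices that can grow without bound, while pool streams are handed out round-robin from a fixed-size pool, so two different logical IDs can share one underlying cudaStream_t. A hypothetical sketch of the lookup (the real one lives in ThreadLocalCUDAObjects):

#include <cstddef>
#include <vector>
#include <c10/cuda/CUDAStream.h>

c10::cuda::CUDAStream stream_for(
    c10::DeviceIndex gpu,
    std::size_t logical_id,
    std::vector<c10::cuda::CUDAStream>& table) {
  // Grow the per-GPU table on demand; each new slot draws the next
  // pool stream, wrapping around after kStreamsPerPool entries.
  while (table.size() <= logical_id) {
    table.push_back(c10::cuda::getStreamFromPool(/*isHighPriority=*/false, gpu));
  }
  return table[logical_id];
}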


@@ -149,7 +149,7 @@ class CuDNNWrapper {
using PerGPUCuDNNStates = std::array<
std::array<SyncedCuDNNState, CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES>,
CAFFE2_COMPILE_TIME_MAX_GPUS>;
C10_COMPILE_TIME_MAX_GPUS>;
static PerGPUCuDNNStates& cudnn_states();
C10_DISABLE_COPY_AND_ASSIGN(CuDNNWrapper);


@@ -151,9 +151,9 @@ class MIOPENWrapper
std::unique_ptr<MIOPENState> state;
};
using PerGPUMIOPENStates =
std::array<std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
CAFFE2_COMPILE_TIME_MAX_GPUS>;
using PerGPUMIOPENStates = std::array<
std::array<SyncedMIOPENState, CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES>,
C10_COMPILE_TIME_MAX_GPUS>;
static PerGPUMIOPENStates& miopen_states();
C10_DISABLE_COPY_AND_ASSIGN(MIOPENWrapper);