mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Use CUDA callback to release deferred-release buffers (#12883)
* Use CUDA callback to release deferred-release buffers Polish * Minor improvements. 1. Reorder an if-else so that frequent cases are checked first. 2. More documentation. * Fix tests. Previously, in CUDAExecutionProvider::OnRunStart, we called GetPerThreadContext in auto& current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent(); so that a CUDAExecutionProvider always owned an active PerThreadContext and the ReleasePerThreadContext in CUDAExecutionProvider::OnRunEnd was always valid. However, this isn't true after we drop the event-based deferred-release code, so we need to check whether CUDAExecutionProvider really owns a PerThreadContext and call ReleasePerThreadContext only if it does. * Follow up for AMD GPU and improve CUDA part's return value.
This commit is contained in:
parent
55c745eefd
commit
28f2e57de5
4 changed files with 252 additions and 149 deletions
|
|
@ -227,22 +227,9 @@ CUDAExecutionProvider::CUDAExecutionProvider(const CUDAExecutionProviderInfo& in
|
|||
}
|
||||
|
||||
CUDAExecutionProvider::~CUDAExecutionProvider() {
|
||||
auto cpu_alloc = GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);
|
||||
{
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto it = deferred_release_cpu_ptr_.begin();
|
||||
while (it != deferred_release_cpu_ptr_.end()) {
|
||||
auto& e = it->first;
|
||||
auto& v = it->second;
|
||||
if (v.recorded)
|
||||
CUDA_CALL_THROW(cudaEventSynchronize(e));
|
||||
for (auto p : v.cpu_ptrs) {
|
||||
cpu_alloc->Free(p);
|
||||
}
|
||||
CUDA_CALL_THROW(cudaEventDestroy(e));
|
||||
it = deferred_release_cpu_ptr_.erase(it);
|
||||
}
|
||||
}
|
||||
// Prevent memory leak when people don't call
|
||||
// OnRunStart and OnRunEnd when calling CudaKernel's.
|
||||
ORT_IGNORE_RETURN_VALUE(EnqueueDeferredRelease());
|
||||
|
||||
// clean up thread local context caches
|
||||
{
|
||||
|
|
@ -352,43 +339,91 @@ void CUDAExecutionProvider::AddDeferredReleaseCPUPtr(void* p) {
|
|||
// when not running in InferenceSession (e.g. Test)
|
||||
// it's OK to not remember the deferred release ptr
|
||||
// as the actual memory will be cleaned in arena allocator dtor
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
if (current_deferred_release_event) {
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto iter = deferred_release_cpu_ptr_.find(current_deferred_release_event);
|
||||
ORT_ENFORCE(iter != deferred_release_cpu_ptr_.end());
|
||||
iter->second.cpu_ptrs.push_back(p);
|
||||
|
||||
// This function should only record pointers returned by
|
||||
// AllocateBufferOnCPUPinned.
|
||||
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_mutex_);
|
||||
void* stream = GetComputeStream();
|
||||
auto it = deferred_release_buffer_pool_.find(stream);
|
||||
if (it != deferred_release_buffer_pool_.end()) {
|
||||
it->second.push_back(p);
|
||||
} else {
|
||||
deferred_release_buffer_pool_[stream] = {p};
|
||||
}
|
||||
}
|
||||
|
||||
struct CpuBuffersInfo {
|
||||
// This struct stores the information needed
|
||||
// to release CPU buffers allocated for GPU kernels.
|
||||
// It's used to enqueue their release after
|
||||
// associated GPU kernels in a CUDA stream.
|
||||
|
||||
// This is a CPU allocator in CUDA EP.
|
||||
// It must be the one used to allocate the
|
||||
// following pointers.
|
||||
AllocatorPtr allocator;
|
||||
// buffers[i] is the i-th pointer added by
|
||||
// AddDeferredReleaseCPUPtr for a specific
|
||||
// CUDA stream. For example, this fields
|
||||
// should contain all values in
|
||||
// deferred_release_buffer_pool_[my_stream]
|
||||
// when release my_stream's buffers.
|
||||
void** buffers;
|
||||
// CPU buffer buffers[i].
|
||||
// Number of buffer points in "buffers".
|
||||
size_t n_buffers;
|
||||
};
|
||||
|
||||
// Host function enqueued on a CUDA stream by EnqueueDeferredRelease.
// raw_info points to a heap-allocated CpuBuffersInfo whose ownership is
// taken over here: every listed buffer is returned to the allocator, then
// the pointer array and the CpuBuffersInfo object itself are destroyed.
// NOTE(review): CUDA documents that host functions must not call CUDA
// APIs; presumably Free on this allocator does not do so - confirm.
static void CUDART_CB ReleaseCpuBufferCallback(void* raw_info) {
  CpuBuffersInfo* const release_info = static_cast<CpuBuffersInfo*>(raw_info);
  void** const first = release_info->buffers;
  void** const last = first + release_info->n_buffers;
  for (void** cursor = first; cursor != last; ++cursor) {
    release_info->allocator->Free(*cursor);
  }
  // The array was created with new[] and the payload with new.
  delete[] first;
  delete release_info;
}
|
||||
|
||||
// Schedules the release of all CPU buffers registered through
// AddDeferredReleaseCPUPtr.
//
// Those buffers are allocated for CUDA kernels (type: CudaKernel) and have
// to be released outside the kernels because they must stay alive during
// asynchronous GPU computation, even after the CPU part (e.g.
// CudaKernel::ComputeInternal) has already returned. For every stream in
// deferred_release_buffer_pool_ a host callback (ReleaseCpuBufferCallback)
// is enqueued on that stream; it frees the buffers once the stream reaches
// it.
//
// Returns Status::OK() when every callback was enqueued. If enqueueing
// fails for a stream, the error is returned immediately; streams that were
// already handled have been removed from the pool (so their buffers cannot
// be released twice by a later call) and the failing stream's buffers stay
// in the pool so that a later call can retry.
Status CUDAExecutionProvider::EnqueueDeferredRelease() {
  std::lock_guard<OrtMutex> lock(deferred_release_mutex_);
  if (deferred_release_buffer_pool_.empty()) {
    return Status::OK();
  }

  // This allocator must be the same as the allocator
  // used in AllocateBufferOnCPUPinned.
  const AllocatorPtr cpu_allocator = GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);

  auto it = deferred_release_buffer_pool_.begin();
  while (it != deferred_release_buffer_pool_.end()) {
    // it->first: a CUDA stream.
    // it->second: CPU buffers associated with kernels running on it->first.
    auto stream = static_cast<cudaStream_t>(it->first);
    const auto& buffers = it->second;

    // Heap objects extend the lifetime of the allocator and the buffer
    // pointers until the end of the callback. They are held by unique_ptr
    // until the callback is successfully enqueued so that an early return
    // below does not leak them.
    std::unique_ptr<CpuBuffersInfo> cpu_buffers_info(new CpuBuffersInfo);
    std::unique_ptr<void*[]> buffer_array(new void*[buffers.size()]);
    for (size_t i = 0; i < buffers.size(); ++i) {
      buffer_array[i] = buffers[i];
    }
    cpu_buffers_info->allocator = cpu_allocator;
    cpu_buffers_info->buffers = buffer_array.get();
    cpu_buffers_info->n_buffers = buffers.size();

    CUDA_RETURN_IF_ERROR(cudaLaunchHostFunc(stream, ReleaseCpuBufferCallback, cpu_buffers_info.get()));

    // The callback now owns the payload and will delete it.
    buffer_array.release();
    cpu_buffers_info.release();

    // These buffers are scheduled for release; forget them right away so
    // that they can never be scheduled a second time.
    it = deferred_release_buffer_pool_.erase(it);
  }
  return Status::OK();
}
|
||||
|
||||
Status CUDAExecutionProvider::OnRunStart() {
|
||||
// always set CUDA device when session::Run() in case it runs in a worker thread
|
||||
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
|
||||
auto cpu_alloc = GetAllocator(0, OrtMemTypeCPU);
|
||||
// check if cudaEvents has passed for deferred release
|
||||
// note that we need to take a mutex in case of multi-threaded Run()
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto it = deferred_release_cpu_ptr_.begin();
|
||||
while (it != deferred_release_cpu_ptr_.end()) {
|
||||
auto& e = it->first;
|
||||
auto& v = it->second;
|
||||
// note that cudaEventQuery returns cudaSuccess before first cudaEventRecord
|
||||
if (v.recorded && cudaSuccess == cudaEventQuery(e)) {
|
||||
for (auto p : v.cpu_ptrs) {
|
||||
cpu_alloc->Free(p);
|
||||
}
|
||||
cudaEvent_t expired_event = it->first;
|
||||
it = deferred_release_cpu_ptr_.erase(it);
|
||||
CUDA_RETURN_IF_ERROR(cudaEventDestroy(expired_event));
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
auto& current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
CUDA_RETURN_IF_ERROR(cudaEventCreate(&current_deferred_release_event, cudaEventDisableTiming));
|
||||
deferred_release_cpu_ptr_.emplace(current_deferred_release_event, DeferredReleaseCPUPtrs());
|
||||
|
||||
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
|
||||
LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
|
||||
GetPerThreadContext().CaptureBegin();
|
||||
|
|
@ -407,21 +442,29 @@ Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
|
|||
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
|
||||
}
|
||||
}
|
||||
// record deferred release event on default stream, and release per_thread_context
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
CUDA_RETURN_IF_ERROR(cudaEventRecord(current_deferred_release_event, static_cast<cudaStream_t>(GetComputeStream())));
|
||||
|
||||
// Enqueue deferred CPU memory release on related streams.
|
||||
// This will release all deferred-release CPU buffers allocated
|
||||
// before calling OnRunEnd.
|
||||
ORT_RETURN_IF_ERROR(EnqueueDeferredRelease());
|
||||
|
||||
if (sync_stream) {
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
|
||||
}
|
||||
|
||||
// If cuda graph is enabled, the per thread context will not be released
|
||||
// because the per thread cuda graph needs to be maintained and replayed for
|
||||
// the next run.
|
||||
if (!IsGraphCaptureEnabled()) {
|
||||
// The reason of !IsGraphCaptureEnabled():
|
||||
// If cuda graph is enabled, the per thread context will not be released
|
||||
// because the per thread cuda graph needs to be maintained and replayed for
|
||||
// the next run.
|
||||
// The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
|
||||
// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
|
||||
// PerThreadContext won't be created and there is nothing to
|
||||
// release. This didn't happen before because we always call
|
||||
// GetPerThreadContext in OnRunStart.
|
||||
if (!IsGraphCaptureEnabled() &&
|
||||
PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
|
||||
ReleasePerThreadContext();
|
||||
}
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,23 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
return GetPerThreadContext().template GetConstOnes<T>(count);
|
||||
}
|
||||
|
||||
// Add CPU buffer to a buffer pool.
|
||||
// They can and only can be released
|
||||
// by calling EnqueueDeferredRelease.
|
||||
// A common pattern is
|
||||
// 1. auto buffer = AllocateBufferOnCPUPinned<char>(128);
|
||||
// 2. Some GPU kernel calls on GPU stream from GetComputeStream.
|
||||
// 3. Call AddDeferredReleaseCPUPtr(buffer.release());
|
||||
// 4. Call EnqueueDeferredRelease();
|
||||
// so that the allocated "buffer" in (1) will be released
|
||||
// only after all GPU kernels in (2) are finished.
|
||||
// (4) is done in OnRunEnd, so the user doesn't need to call
|
||||
// it in most cases.
|
||||
void AddDeferredReleaseCPUPtr(void* p);
|
||||
// Release all buffers added by
|
||||
// AddDeferredReleaseCPUPtr using
|
||||
// GPU callback (so it's async).
|
||||
Status EnqueueDeferredRelease();
|
||||
|
||||
template <typename T>
|
||||
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
|
||||
|
|
@ -112,13 +128,17 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
bool external_stream_ = false;
|
||||
cudaStream_t stream_ = nullptr;
|
||||
|
||||
struct DeferredReleaseCPUPtrs {
|
||||
bool recorded = false;
|
||||
std::vector<void*> cpu_ptrs;
|
||||
};
|
||||
|
||||
std::unordered_map<cudaEvent_t, DeferredReleaseCPUPtrs> deferred_release_cpu_ptr_;
|
||||
OrtMutex deferred_release_cpu_ptr_mutex_;
|
||||
// deferred_release_buffer_pool_[my_stream] store all CPU buffers associated with
|
||||
// CUDA kernels running on my_stream (type: cudaStream_t).
|
||||
// Buffers' release is enqueued as a CUDA callback onto the associated stream (aka
|
||||
// stream returned by GetComputeStream when calling AddDeferredReleaseCPUPtr) in OnRunEnd.
|
||||
// Those are pointers allocated by AllocateBufferOnCPUPinned and should be released
|
||||
// by CPU Allocator's Free function.
|
||||
std::unordered_map<void*, std::vector<void*>> deferred_release_buffer_pool_;
|
||||
// To add a pointer to deferred_release_buffer_pool_, we need a mutex because
|
||||
// different threads may create CPU buffers at the same time. Releasing
|
||||
// buffers also needs this mutex.
|
||||
OrtMutex deferred_release_mutex_;
|
||||
|
||||
class PerThreadContext final {
|
||||
public:
|
||||
|
|
@ -138,10 +158,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
return cublas_lt_handle_;
|
||||
}
|
||||
|
||||
cudaEvent_t& GetCurrentDeferredReleaseEvent() {
|
||||
return current_deferred_release_event_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* GetConstOnes(size_t count) {
|
||||
if (std::is_same<T, float>::value) {
|
||||
|
|
@ -188,11 +204,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
cudnnHandle_t cudnn_handle_ = nullptr;
|
||||
cublasLtHandle_t cublas_lt_handle_ = nullptr;
|
||||
|
||||
// deferred release for temporary CPU pinned memory used in cudaMemcpyAsync
|
||||
// note that cudaEvent will be assigned at OnRunEnd() when PerThreadContext destory
|
||||
// so the ownership is passed to deferred_release_cpu_ptr_
|
||||
cudaEvent_t current_deferred_release_event_ = nullptr;
|
||||
|
||||
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
|
||||
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
|
||||
|
|
@ -253,4 +264,4 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
void ReleasePerThreadContext() const;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -183,22 +183,9 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
|
|||
}
|
||||
|
||||
ROCMExecutionProvider::~ROCMExecutionProvider() {
|
||||
auto cpu_alloc = GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);
|
||||
{
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto it = deferred_release_cpu_ptr_.begin();
|
||||
while (it != deferred_release_cpu_ptr_.end()) {
|
||||
auto& e = it->first;
|
||||
auto& v = it->second;
|
||||
if (v.recorded)
|
||||
HIP_CALL_THROW(hipEventSynchronize(e));
|
||||
for (auto p : v.cpu_ptrs) {
|
||||
cpu_alloc->Free(p);
|
||||
}
|
||||
HIP_CALL_THROW(hipEventDestroy(e));
|
||||
it = deferred_release_cpu_ptr_.erase(it);
|
||||
}
|
||||
}
|
||||
// Prevent memory leak when people don't call
|
||||
// OnRunStart and OnRunEnd when calling RocmKernel's.
|
||||
ORT_IGNORE_RETURN_VALUE(EnqueueDeferredRelease());
|
||||
|
||||
// clean up thread local context caches
|
||||
{
|
||||
|
|
@ -294,63 +281,114 @@ void ROCMExecutionProvider::AddDeferredReleaseCPUPtr(void* p) {
|
|||
// when not running in InferenceSession (e.g. Test)
|
||||
// it's OK to not remember the deferred release ptr
|
||||
// as the actual memory will be cleaned in arena allocator dtor
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
if (current_deferred_release_event) {
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto iter = deferred_release_cpu_ptr_.find(current_deferred_release_event);
|
||||
ORT_ENFORCE(iter != deferred_release_cpu_ptr_.end());
|
||||
iter->second.cpu_ptrs.push_back(p);
|
||||
|
||||
// This function should only record pointers returned by
|
||||
// AllocateBufferOnCPUPinned.
|
||||
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_mutex_);
|
||||
void* stream = GetComputeStream();
|
||||
auto it = deferred_release_buffer_pool_.find(stream);
|
||||
if (it != deferred_release_buffer_pool_.end()) {
|
||||
it->second.push_back(p);
|
||||
} else {
|
||||
deferred_release_buffer_pool_[stream] = {p};
|
||||
}
|
||||
}
|
||||
|
||||
struct CpuBuffersInfo {
|
||||
// This struct stores the information needed
|
||||
// to release CPU buffers allocated for GPU kernels.
|
||||
// It's used to enqueue their release after
|
||||
// associated GPU kernels in a GPU stream.
|
||||
|
||||
// This is a CPU allocator in GPU EP.
|
||||
// It must be the one used to allocate the
|
||||
// following pointers.
|
||||
AllocatorPtr allocator;
|
||||
// buffers[i] is the i-th pointer added by
|
||||
// AddDeferredReleaseCPUPtr for a specific
|
||||
// GPU stream. For example, this fields
|
||||
// should contain all values in
|
||||
// deferred_release_buffer_pool_[my_stream]
|
||||
// when release my_stream's buffers.
|
||||
void** buffers;
|
||||
// CPU buffer buffers[i].
|
||||
// Number of buffer points in "buffers".
|
||||
size_t n_buffers;
|
||||
};
|
||||
|
||||
// Stream callback registered by EnqueueDeferredRelease through
// hipStreamAddCallback. The stream and status arguments are not used.
// raw_info points to a heap-allocated CpuBuffersInfo whose ownership is
// taken over here: every listed buffer is returned to the allocator, then
// the pointer array and the CpuBuffersInfo object itself are destroyed.
void ReleaseCpuBufferCallback(hipStream_t /*stream*/, hipError_t /*status*/, void* raw_info) {
  CpuBuffersInfo* const release_info = static_cast<CpuBuffersInfo*>(raw_info);
  void** const first = release_info->buffers;
  void** const last = first + release_info->n_buffers;
  for (void** cursor = first; cursor != last; ++cursor) {
    release_info->allocator->Free(*cursor);
  }
  // The array was created with new[] and the payload with new.
  delete[] first;
  delete release_info;
}
|
||||
|
||||
// Schedules the release of all CPU buffers registered through
// AddDeferredReleaseCPUPtr.
//
// Those buffers are allocated for GPU kernels (type: RocmKernel) and have
// to be released outside the kernels because they must stay alive during
// asynchronous GPU computation, even after the CPU part (e.g.
// RocmKernel::ComputeInternal) has already returned. For every stream in
// deferred_release_buffer_pool_ a stream callback
// (ReleaseCpuBufferCallback) is enqueued on that stream; it frees the
// buffers once the stream reaches it.
//
// Returns Status::OK() when every callback was enqueued. If enqueueing
// fails for a stream, the error is returned immediately; streams that were
// already handled have been removed from the pool (so their buffers cannot
// be released twice by a later call) and the failing stream's buffers stay
// in the pool so that a later call can retry.
Status ROCMExecutionProvider::EnqueueDeferredRelease() {
  std::lock_guard<OrtMutex> lock(deferred_release_mutex_);
  if (deferred_release_buffer_pool_.empty()) {
    return Status::OK();
  }

  // This allocator must be the same as the allocator
  // used in AllocateBufferOnCPUPinned.
  const AllocatorPtr cpu_allocator = GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);

  auto it = deferred_release_buffer_pool_.begin();
  while (it != deferred_release_buffer_pool_.end()) {
    // it->first: a ROCM stream.
    // it->second: CPU buffers associated with kernels running on it->first.
    auto stream = static_cast<hipStream_t>(it->first);
    const auto& buffers = it->second;

    // Heap objects extend the lifetime of the allocator and the buffer
    // pointers until the end of the callback. They are held by unique_ptr
    // until the callback is successfully enqueued so that an early return
    // below does not leak them.
    std::unique_ptr<CpuBuffersInfo> cpu_buffers_info(new CpuBuffersInfo);
    std::unique_ptr<void*[]> buffer_array(new void*[buffers.size()]);
    for (size_t i = 0; i < buffers.size(); ++i) {
      buffer_array[i] = buffers[i];
    }
    cpu_buffers_info->allocator = cpu_allocator;
    cpu_buffers_info->buffers = buffer_array.get();
    cpu_buffers_info->n_buffers = buffers.size();

    // TODO(wechi): CUDA deprecates cudaStreamAddCallback and
    // uses another API, cudaLaunchHostFunc (which can be
    // captured in CUDA graph). Once AMD adds a similar feature,
    // we should replace the following line with
    // hipLaunchHostFunc(stream, ReleaseCpuBufferCallback, cpu_buffers_info);
    HIP_RETURN_IF_ERROR(hipStreamAddCallback(stream, ReleaseCpuBufferCallback, cpu_buffers_info.get(), 0));

    // The callback now owns the payload and will delete it.
    buffer_array.release();
    cpu_buffers_info.release();

    // These buffers are scheduled for release; forget them right away so
    // that they can never be scheduled a second time by the next
    // EnqueueDeferredRelease call.
    it = deferred_release_buffer_pool_.erase(it);
  }
  return Status::OK();
}
|
||||
|
||||
Status ROCMExecutionProvider::OnRunStart() {
|
||||
// always set ROCM device when session::Run() in case it runs in a worker thread
|
||||
HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
|
||||
auto cpu_alloc = GetAllocator(0, OrtMemTypeCPU);
|
||||
// check if hipEvents has passed for deferred release
|
||||
// note that we need to take a mutex in case of multi-threaded Run()
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
auto it = deferred_release_cpu_ptr_.begin();
|
||||
while (it != deferred_release_cpu_ptr_.end()) {
|
||||
auto& e = it->first;
|
||||
auto& v = it->second;
|
||||
// note that hipEventQuery returns hipSuccess before first hipEventRecord
|
||||
if (v.recorded) {
|
||||
auto event_query_status = hipEventQuery(e);
|
||||
if (event_query_status == hipSuccess) {
|
||||
for (auto p : v.cpu_ptrs) {
|
||||
cpu_alloc->Free(p);
|
||||
}
|
||||
HIP_RETURN_IF_ERROR(hipEventDestroy(e));
|
||||
it = deferred_release_cpu_ptr_.erase(it);
|
||||
} else if (event_query_status == hipErrorNotReady) {
|
||||
// ignore and clear the error if not ready; void to silence nodiscard
|
||||
(void)hipGetLastError();
|
||||
it++;
|
||||
} else {
|
||||
HIP_RETURN_IF_ERROR(event_query_status);
|
||||
}
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
auto& current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
HIP_RETURN_IF_ERROR(hipEventCreateWithFlags(&current_deferred_release_event, hipEventDisableTiming));
|
||||
deferred_release_cpu_ptr_.emplace(current_deferred_release_event, DeferredReleaseCPUPtrs());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
|
||||
// record deferred release event on default stream, and release per_thread_context
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
HIP_RETURN_IF_ERROR(hipEventRecord(current_deferred_release_event, static_cast<hipStream_t>(GetComputeStream())));
|
||||
// Enqueue deferred CPU memory release on related streams.
|
||||
// This will release all deferred-release CPU buffers allocated
|
||||
// before calling OnRunEnd.
|
||||
ORT_RETURN_IF_ERROR(EnqueueDeferredRelease());
|
||||
|
||||
if (sync_stream) {
|
||||
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
|
||||
}
|
||||
ReleasePerThreadContext();
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;
|
||||
|
||||
// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
|
||||
// PerThreadContext won't be created and there is nothing to
|
||||
// release. This didn't happen before because we always call
|
||||
// GetPerThreadContext in OnRunStart.
|
||||
if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
|
||||
ReleasePerThreadContext();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,7 +53,23 @@ class ROCMExecutionProvider : public IExecutionProvider {
|
|||
return GetPerThreadContext().template GetConstOnes<T>(count);
|
||||
}
|
||||
|
||||
// Add CPU buffer to a buffer pool.
|
||||
// They can and only can be released
|
||||
// by calling EnqueueDeferredRelease.
|
||||
// A common pattern is
|
||||
// 1. auto buffer = AllocateBufferOnCPUPinned<char>(128);
|
||||
// 2. Some GPU kernel calls on GPU stream from GetComputeStream.
|
||||
// 3. Call AddDeferredReleaseCPUPtr(buffer.release());
|
||||
// 4. Call EnqueueDeferredRelease();
|
||||
// so that the allocated "buffer" in (1) will be released
|
||||
// only after all GPU kernels in (2) are finished.
|
||||
// (4) is done in OnRunEnd, so the user doesn't need to call
|
||||
// it in most cases.
|
||||
void AddDeferredReleaseCPUPtr(void* p);
|
||||
// Release all buffers added by
|
||||
// AddDeferredReleaseCPUPtr using
|
||||
// GPU callback (so it's async).
|
||||
Status EnqueueDeferredRelease();
|
||||
|
||||
template <typename T>
|
||||
IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
|
||||
|
|
@ -101,13 +117,17 @@ class ROCMExecutionProvider : public IExecutionProvider {
|
|||
bool external_stream_ = false;
|
||||
hipStream_t stream_ = nullptr;
|
||||
|
||||
struct DeferredReleaseCPUPtrs {
|
||||
bool recorded = false;
|
||||
std::vector<void*> cpu_ptrs;
|
||||
};
|
||||
|
||||
std::unordered_map<hipEvent_t, DeferredReleaseCPUPtrs> deferred_release_cpu_ptr_;
|
||||
OrtMutex deferred_release_cpu_ptr_mutex_;
|
||||
// deferred_release_buffer_pool_[my_stream] store all CPU buffers associated with
|
||||
// ROCm kernels running on my_stream (type: hipStream_t).
|
||||
// Buffers' release is enqueued as a HIP stream callback onto the associated stream (aka
|
||||
// stream returned by GetComputeStream when calling AddDeferredReleaseCPUPtr) in OnRunEnd.
|
||||
// Those are pointers allocated by AllocateBufferOnCPUPinned and should be released
|
||||
// by CPU Allocator's Free function.
|
||||
std::unordered_map<void*, std::vector<void*>> deferred_release_buffer_pool_;
|
||||
// To add a pointer to deferred_release_buffer_pool_, we need a mutex because
|
||||
// different threads may create CPU buffers at the same time. Releasing
|
||||
// buffers also needs this mutex.
|
||||
OrtMutex deferred_release_mutex_;
|
||||
|
||||
class PerThreadContext final {
|
||||
public:
|
||||
|
|
@ -123,10 +143,6 @@ class ROCMExecutionProvider : public IExecutionProvider {
|
|||
return miopen_handle_;
|
||||
}
|
||||
|
||||
hipEvent_t& GetCurrentDeferredReleaseEvent() {
|
||||
return current_deferred_release_event_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T* GetConstOnes(size_t count) {
|
||||
if (std::is_same<T, float>::value) {
|
||||
|
|
@ -158,11 +174,6 @@ class ROCMExecutionProvider : public IExecutionProvider {
|
|||
rocblas_handle rocblas_handle_ = nullptr;
|
||||
miopenHandle_t miopen_handle_ = nullptr;
|
||||
|
||||
// deferred release for temporary CPU pinned memory used in hipMemcpyAsync
|
||||
// note that hipEvent will be assigned at OnRunEnd() when PerThreadContext destory
|
||||
// so the ownership is passed to deferred_release_cpu_ptr_
|
||||
hipEvent_t current_deferred_release_event_ = nullptr;
|
||||
|
||||
std::unique_ptr<rocm::IConstantBuffer<float>> constant_ones_float_;
|
||||
std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
|
||||
std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue