From ff23b9ff55b7f90694a2e6abd64ec69afaa3c820 Mon Sep 17 00:00:00 2001 From: Sherlock Date: Thu, 21 Oct 2021 11:28:25 -0700 Subject: [PATCH] Avoid cudaStreamSync at the end of Forward/Backward (#9470) * Skip cudaStreamSynchronize at the end of fw * skip sync stream for end of backward --- include/onnxruntime/core/framework/execution_provider.h | 2 +- .../core/providers/cuda/cuda_execution_provider.cc | 6 ++++-- onnxruntime/core/providers/cuda/cuda_execution_provider.h | 2 +- .../core/providers/rocm/rocm_execution_provider.cc | 6 ++++-- onnxruntime/core/providers/rocm/rocm_execution_provider.h | 2 +- .../providers/tensorrt/tensorrt_execution_provider.cc | 8 +++++--- .../core/providers/tensorrt/tensorrt_execution_provider.h | 2 +- onnxruntime/core/session/inference_session.cc | 4 ++-- 8 files changed, 19 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index e23229c700..bea28f31b2 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -160,7 +160,7 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd() { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } /** Called when session creation is complete diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 6dd68a72bf..14fa9fd737 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -333,11 +333,13 @@ Status CUDAExecutionProvider::OnRunStart() { return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd() { +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { // record deferred release event on default stream, and release per_thread_context auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent(); CUDA_RETURN_IF_ERROR(cudaEventRecord(current_deferred_release_event, static_cast(GetComputeStream()))); - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(GetComputeStream()))); + if (sync_stream) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(GetComputeStream()))); + } ReleasePerThreadContext(); std::lock_guard lock(deferred_release_cpu_ptr_mutex_); deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index d96ee26fe2..67120de11c 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,7 +29,7 @@ class CUDAExecutionProvider : public IExecutionProvider { Status OnRunStart() override; - Status OnRunEnd() override; + Status OnRunEnd(bool sync_stream) override; const void* GetExecutionHandle() const noexcept override { // The CUDA interface does not return anything interesting. diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 5ba380f0e4..b7c207f8bd 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -336,11 +336,13 @@ Status ROCMExecutionProvider::OnRunStart() { return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd() { +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { // record deferred release event on default stream, and release per_thread_context auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent(); HIP_RETURN_IF_ERROR(hipEventRecord(current_deferred_release_event, static_cast(GetComputeStream()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetComputeStream()))); + if (sync_stream) { + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetComputeStream()))); + } ReleasePerThreadContext(); std::lock_guard lock(deferred_release_cpu_ptr_mutex_); deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 765d5766e1..a0b8817d87 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -29,7 +29,7 @@ class ROCMExecutionProvider : public IExecutionProvider { Status OnRunStart() override; - Status OnRunEnd() override; + Status OnRunEnd(bool sync_stream) override; const void* GetExecutionHandle() const noexcept override { // The ROCM interface does not return anything interesting. diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index acf7e09b5b..ed7f654ef5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -645,8 +645,10 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(static_cast(GetComputeStream())); } -Status TensorrtExecutionProvider::OnRunEnd() { - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(GetComputeStream()))); +Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { + if (sync_stream) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast(GetComputeStream()))); + } return Status::OK(); } @@ -911,7 +913,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect if (input_shape != nullptr) { auto dim_size = input_shape->dim_size(); for (int i = 0; i < dim_size; ++i) { - auto &dim = input_shape->dim(i); + auto& dim = input_shape->dim(i); if (!dim.has_dim_value() && !dim.has_dim_param()) { has_dim_value_or_param = false; break; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index d2f305f7ff..a1e3e14f19 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -125,7 +125,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterAllocator(std::shared_ptr allocator_manager) override; - Status OnRunEnd() override; + Status OnRunEnd(bool sync_stream) override; Status SetComputeStream(void* stream) override; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b5a22a4d37..44bd9d37c7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1730,7 +1730,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(); + auto status = xp->OnRunEnd(/*sync_stream*/ false); ORT_CHECK_AND_SET_RETVAL(status); } @@ -1844,7 +1844,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(); + auto status = xp->OnRunEnd(/*sync_stream*/ true); ORT_CHECK_AND_SET_RETVAL(status); }