From f720166887196b8a7e52cb9c5ef3fed48250e97e Mon Sep 17 00:00:00 2001
From: Ke Zhang
Date: Wed, 17 Jul 2019 21:12:18 -0700
Subject: [PATCH] register gpu data transfer only when there's nvidia gpu related eps. (#1420)

---
 onnxruntime/core/session/inference_session.cc | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 248e21d58a..03ecdcf524 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -103,11 +103,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, loggin
 
   InitLogger(logging_manager);
 
-  // Register data transfer methods.
-  data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
-#ifdef USE_CUDA
-  data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
-#endif
   session_state_.SetDataTransferMgr(&data_transfer_mgr_);
 
   // The threadpool is currently evolving. We will always create a per session threadpool.
@@ -482,6 +477,16 @@ common::Status InferenceSession::Initialize() {
                                                    std::make_unique<CPUExecutionProvider>(epi)));
     }
 
+    // Register data transfer methods.
+    data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
+#ifdef USE_CUDA
+    // TODO: this should be refactored later by exposing separate API to allow users to register different data transfers for different devices.
+    bool is_nvidia_gpu_used = (nullptr != execution_providers_.Get(kCudaExecutionProvider)) || (nullptr != execution_providers_.Get(kTensorrtExecutionProvider));
+    if (is_nvidia_gpu_used) {
+      data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
+    }
+#endif
+
     if (!session_options_.enable_sequential_execution &&
         execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) {
       LOGS(*session_logger_, ERROR) << "Parallel execution is currently not supported "