Register GPU data transfer only when there are NVIDIA GPU-related EPs. (#1420)

This commit is contained in:
Ke Zhang 2019-07-17 21:12:18 -07:00 committed by GitHub
parent db61eb4cd7
commit f720166887
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -103,11 +103,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, loggin
InitLogger(logging_manager);
// Register data transfer methods.
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
#ifdef USE_CUDA
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
#endif
session_state_.SetDataTransferMgr(&data_transfer_mgr_);
// The threadpool is currently evolving. We will always create a per session threadpool.
@ -482,6 +477,16 @@ common::Status InferenceSession::Initialize() {
std::make_unique<CPUExecutionProvider>(epi)));
}
// Register data transfer methods.
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
#ifdef USE_CUDA
// TODO: this should be refactored later by exposing a separate API that allows users to register different data transfers for different devices.
bool is_nvidia_gpu_used = (nullptr != execution_providers_.Get(kCudaExecutionProvider)) || (nullptr != execution_providers_.Get(kTensorrtExecutionProvider));
if (is_nvidia_gpu_used) {
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
}
#endif
if (!session_options_.enable_sequential_execution &&
execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) {
LOGS(*session_logger_, ERROR) << "Parallel execution is currently not supported "