mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
register gpu data transfer only when there's nvidia gpu related eps. (#1420)
This commit is contained in:
parent
db61eb4cd7
commit
f720166887
1 changed files with 10 additions and 5 deletions
|
|
@ -103,11 +103,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, loggin
|
|||
|
||||
InitLogger(logging_manager);
|
||||
|
||||
// Register data transfer methods.
|
||||
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
|
||||
#ifdef USE_CUDA
|
||||
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
|
||||
#endif
|
||||
session_state_.SetDataTransferMgr(&data_transfer_mgr_);
|
||||
|
||||
// The threadpool is currently evolving. We will always create a per session threadpool.
|
||||
|
|
@ -482,6 +477,16 @@ common::Status InferenceSession::Initialize() {
|
|||
std::make_unique<CPUExecutionProvider>(epi)));
|
||||
}
|
||||
|
||||
// Register data transfer methods.
|
||||
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<CPUDataTransfer>());
|
||||
#ifdef USE_CUDA
|
||||
// TODO: this should be refactored later by exposing separate API to allow users to register different data transfers for different devices.
|
||||
bool is_nvidia_gpu_used = (nullptr != execution_providers_.Get(kCudaExecutionProvider)) || (nullptr != execution_providers_.Get(kTensorrtExecutionProvider));
|
||||
if (is_nvidia_gpu_used) {
|
||||
data_transfer_mgr_.RegisterDataTransfer(std::make_unique<GPUDataTransfer>());
|
||||
}
|
||||
#endif
|
||||
|
||||
if (!session_options_.enable_sequential_execution &&
|
||||
execution_providers_.Get(onnxruntime::kCudaExecutionProvider)) {
|
||||
LOGS(*session_logger_, ERROR) << "Parallel execution is currently not supported "
|
||||
|
|
|
|||
Loading…
Reference in a new issue