diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 82b160a704..b0171f25be 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -98,7 +98,6 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons // info on the logic to create the node_info_vec. // for (auto& node_info : node_info_vec) { auto& node_info = node_info_vec.front(); - if (node_info.p_node == nullptr) { // dummy entry for an input that we didn't find a use of in the graph. // use the input as is given we don't believe it's actually needed. @@ -409,25 +408,18 @@ static common::Status CachedCopyOutputsAcrossDevices( return Status::OK(); } -// check if all the execution providers use the same allocator. if so, no copies between devices should be required, -// and the overall status for DeviceCopyChecks can be set to NoCopy static DeviceCopyCheck CheckExecutionProviders(const ExecutionProviders& execution_providers) { - bool all_cpu = true; for (const auto& execution_provider : execution_providers) { - const auto& allocators = execution_provider->GetAllocators(); - // this won't work as desired until multiple providers can share the CPU Allocator and the logic here is updated - // to detect that.. - // it will currently handle the scenario when only the CPUExecutionProvider is registered though - if (!std::all_of(allocators.cbegin(), allocators.cend(), - [](const gsl::not_null& allocator) { - return strcmp(allocator->Info().name, CPU) == 0; - })) { - all_cpu = false; - break; + if (execution_provider->Type() != onnxruntime::kCpuExecutionProvider && + execution_provider->Type() != onnxruntime::kMklDnnExecutionProvider && + execution_provider->Type() != onnxruntime::kNGraphExecutionProvider && + execution_provider->Type() != onnxruntime::kNupharExecutionProvider && + execution_provider->Type() != onnxruntime::kOpenVINOExecutionProvider) { + return DeviceCopyCheck::Unknown; } } - return all_cpu ? DeviceCopyCheck::NoCopy : DeviceCopyCheck::Unknown; + return DeviceCopyCheck::NoCopy; } // execute graph with cached info from FeedsFetchesManager.