From b53f40a886114bbcd093fe3a417b4c233fb3ed2c Mon Sep 17 00:00:00 2001 From: Ke Zhang Date: Wed, 21 Aug 2019 19:58:05 -0700 Subject: [PATCH] update set fetches for execution with allocation plan. (#1668) --- onnxruntime/core/framework/utils.cc | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5fc78a4f99..fdbea44864 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -71,7 +71,7 @@ static Status CopyMLValue(const DataTransferManager& data_transfer_mgr, const FeedsFetchesManager::MLValueCopyInfo& copy_info, const OrtValue& source_mlvalue, OrtValue& target_mlvalue) { - if (copy_info.allocation_provider == nullptr){ + if (copy_info.allocation_provider == nullptr) { target_mlvalue = source_mlvalue; return Status::OK(); } @@ -202,18 +202,16 @@ static common::Status CachedCopyInputsAcrossDevices( // Setup fetches for execution. Use any provided fetches directly if the provider matches. // If the provider doesn't match, we don't know what device the execution output may be on, so can't assume the output // can be returned to the user directly. -// TODO: We should be able to use the allocation plan to know which device an output will be on. static common::Status SetupFetchesForExecute(const SessionState& session_state, const std::vector& output_names, std::vector& fetches, std::vector& new_fetches, std::vector* copy_to_new_fetches_cached_values) { ORT_ENFORCE(new_fetches.empty()); - - const auto& execution_providers = session_state.GetExecutionProviders(); auto num_outputs = output_names.size(); - new_fetches.resize(num_outputs); + const auto& name_to_id = session_state.GetOrtValueNameIdxMap(); + const auto* exec_plan = session_state.GetExecutionPlan(); // track which fetches can be copied to new_fetches and used directly in the execution. std::vector local_can_copy_flags(num_outputs, false); @@ -254,16 +252,12 @@ static common::Status SetupFetchesForExecute(const SessionState& session_state, continue; } - const auto& node_provider_type = node.GetExecutionProviderType(); - const auto& provided_tensor = provided_mlvalue.Get(); - const auto& provided_tensor_loc = provided_tensor.Location(); - const auto* tensor_provider = execution_providers.Get(provided_tensor_loc); - if (!tensor_provider) { - tensor_provider = execution_providers.Get(onnxruntime::kCpuExecutionProvider); - } + int arg_index; + ORT_RETURN_IF_ERROR(name_to_id.GetIdx(arg->Name(), arg_index)); + const auto& planned_device = exec_plan->GetLocation(arg_index).device; + const auto& provided_tensor_device = provided_mlvalue.Get().Location().device; - auto tensor_provider_type = tensor_provider->Type(); - if (node_provider_type == tensor_provider_type) { + if (planned_device == provided_tensor_device) { new_fetches[idx] = fetches[idx]; local_can_copy_flags[idx] = true; continue;