diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index bf76d3bcbd..96bdb44b65 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -268,9 +268,12 @@ const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state, static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session_state, const std::string& input_name, MLValueCopyInfo& copy_info) { -#ifdef ENABLE_TRAINING std::vector node_info_vec; +#ifdef ENABLE_TRAINING if (session_state.GetInputNodeInfo(input_name, node_info_vec) == Status::OK()) { +#else + ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec)); +#endif const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine if (node_info.p_node == nullptr) { @@ -280,6 +283,7 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session copy_info.target_device = *node_info.device; +#ifdef ENABLE_TRAINING } else { // This input might be for an intermediate tensor for partial graph execution. const auto* exec_plan = session_state.GetExecutionPlan(); @@ -289,22 +293,9 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session const auto& device = exec_plan->GetLocation(index).device; copy_info.target_device = device; } - - return Status::OK(); -#else - std::vector node_info_vec; - ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec)); - const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine - - if (node_info.p_node == nullptr) { - // ignore dummy entry for an input that we didn't find a use of in the graph. - return Status::OK(); - } - - copy_info.target_device = *node_info.device; - - return Status::OK(); #endif + + return Status::OK(); } static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state, @@ -508,6 +499,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons } MLValueCopyInfo copy_info; + // Sets copy_info.target_device. ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info)); #if !defined(DISABLE_SPARSE_TENSORS) copy_info.source_device = (orig_mlvalue.IsTensor()) @@ -517,6 +509,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons copy_info.source_device = orig_mlvalue.Get().Location().device; #endif + // copy_info.target_device is not set leaving to be equal to CPU. return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue); } diff --git a/onnxruntime/core/session/IOBinding.cc b/onnxruntime/core/session/IOBinding.cc index c1609b7ccd..e0dd66d188 100644 --- a/onnxruntime/core/session/IOBinding.cc +++ b/onnxruntime/core/session/IOBinding.cc @@ -11,38 +11,38 @@ namespace onnxruntime { IOBinding::IOBinding(const SessionState& session_state) : session_state_(session_state) { } -static std::pair Contains(const std::vector& names, const std::string& name) { - auto it = std::find(std::begin(names), std::end(names), name); - if (it == std::end(names)) { - return {false, 0}; - } - return {true, it - std::begin(names)}; -} - common::Status IOBinding::BindInput(const std::string& name, const OrtValue& ml_value) { - auto rc = Contains(feed_names_, name); + auto it = mapped_feed_names_.emplace(name, feed_names_.size()); - auto add_or_replace = [this, &name](const bool exists, size_t index, const OrtValue& value) { - if (exists) { - feeds_[index] = value; - } else { + auto add_or_replace = [&](const OrtValue& value) { + if (it.second) { feed_names_.push_back(name); feeds_.push_back(value); + } else { + feeds_[it.first->second] = value; } }; if (ml_value.IsTensor() || ml_value.IsSparseTensor()) { OrtValue new_mlvalue; + // Do not replace new_mlvalue by feeds_[index] in the following line. + // It may copy the data instead of copying the pointer. + // When OrtValue is empty, the pointer is copied. When it is not + // (if feeds_[index] is not for example), + // CopyOneInputAcrossDevices has a different behaviour. ORT_RETURN_IF_ERROR(utils::CopyOneInputAcrossDevices(session_state_, name, ml_value, new_mlvalue)); - add_or_replace(rc.first, rc.second, new_mlvalue); + add_or_replace(new_mlvalue); } else { - add_or_replace(rc.first, rc.second, ml_value); + add_or_replace(ml_value); } + ORT_ENFORCE(mapped_feed_names_.size() == feed_names_.size(), "Size mismatch:", mapped_feed_names_.size(), "!=", feed_names_.size(), " index=", it.first->second, " it.second=", it.second); + return Status::OK(); } void IOBinding::ClearInputs() { + mapped_feed_names_.clear(); feed_names_.clear(); feeds_.clear(); } @@ -88,20 +88,23 @@ common::Status IOBinding::BindOutput(const std::string& name, OrtDevice device) } common::Status IOBinding::BindOutputImpl(const std::string& name, const OrtValue& ml_value, OrtDevice device) { - auto rc = Contains(output_names_, name); - if (rc.first) { - outputs_[rc.second] = ml_value; - outputs_device_info_[rc.second] = device; - } else { + auto it = mapped_output_names_.emplace(name, output_names_.size()); + size_t index = it.first->second; + if (it.second) { output_names_.push_back(name); outputs_.push_back(ml_value); outputs_device_info_.push_back(device); + } else { + outputs_[index] = ml_value; + outputs_device_info_[index] = device; } + ORT_ENFORCE(mapped_output_names_.size() == output_names_.size(), "Size mismatch", mapped_output_names_.size(), "!=", output_names_.size()); return Status::OK(); } void IOBinding::ClearOutputs() { + mapped_output_names_.clear(); output_names_.clear(); outputs_.clear(); outputs_device_info_.clear(); diff --git a/onnxruntime/core/session/IOBinding.h b/onnxruntime/core/session/IOBinding.h index 61f7557f03..52b8b6aaf7 100644 --- a/onnxruntime/core/session/IOBinding.h +++ b/onnxruntime/core/session/IOBinding.h @@ -100,8 +100,10 @@ class IOBinding { const SessionState& session_state_; std::vector feed_names_; + std::unordered_map mapped_feed_names_; std::vector feeds_; std::vector output_names_; + std::unordered_map mapped_output_names_; std::vector outputs_; std::vector outputs_device_info_;