mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Improve iobinding, faster name search (#10005)
* Improve iobinding, faster name search
This commit is contained in:
parent
3ea7fb0f9f
commit
e38e51ea8e
3 changed files with 34 additions and 36 deletions
|
|
@ -268,9 +268,12 @@ const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
|
|||
static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session_state,
|
||||
const std::string& input_name,
|
||||
MLValueCopyInfo& copy_info) {
|
||||
#ifdef ENABLE_TRAINING
|
||||
std::vector<SessionState::NodeInfo> node_info_vec;
|
||||
#ifdef ENABLE_TRAINING
|
||||
if (session_state.GetInputNodeInfo(input_name, node_info_vec) == Status::OK()) {
|
||||
#else
|
||||
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
|
||||
#endif
|
||||
const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine
|
||||
|
||||
if (node_info.p_node == nullptr) {
|
||||
|
|
@ -280,6 +283,7 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
|
|||
|
||||
copy_info.target_device = *node_info.device;
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
} else {
|
||||
// This input might be for an intermediate tensor for partial graph execution.
|
||||
const auto* exec_plan = session_state.GetExecutionPlan();
|
||||
|
|
@ -289,22 +293,9 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
|
|||
const auto& device = exec_plan->GetLocation(index).device;
|
||||
copy_info.target_device = device;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
#else
|
||||
std::vector<SessionState::NodeInfo> node_info_vec;
|
||||
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
|
||||
const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine
|
||||
|
||||
if (node_info.p_node == nullptr) {
|
||||
// ignore dummy entry for an input that we didn't find a use of in the graph.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
copy_info.target_device = *node_info.device;
|
||||
|
||||
return Status::OK();
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state,
|
||||
|
|
@ -508,6 +499,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
}
|
||||
|
||||
MLValueCopyInfo copy_info;
|
||||
// Sets copy_info.target_device.
|
||||
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info));
|
||||
#if !defined(DISABLE_SPARSE_TENSORS)
|
||||
copy_info.source_device = (orig_mlvalue.IsTensor())
|
||||
|
|
@ -517,6 +509,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
|
|||
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
|
||||
#endif
|
||||
|
||||
// copy_info.target_device is not set, leaving it equal to CPU.
|
||||
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,38 +11,38 @@ namespace onnxruntime {
|
|||
IOBinding::IOBinding(const SessionState& session_state) : session_state_(session_state) {
|
||||
}
|
||||
|
||||
static std::pair<bool, size_t> Contains(const std::vector<std::string>& names, const std::string& name) {
|
||||
auto it = std::find(std::begin(names), std::end(names), name);
|
||||
if (it == std::end(names)) {
|
||||
return {false, 0};
|
||||
}
|
||||
return {true, it - std::begin(names)};
|
||||
}
|
||||
|
||||
common::Status IOBinding::BindInput(const std::string& name, const OrtValue& ml_value) {
|
||||
auto rc = Contains(feed_names_, name);
|
||||
auto it = mapped_feed_names_.emplace(name, feed_names_.size());
|
||||
|
||||
auto add_or_replace = [this, &name](const bool exists, size_t index, const OrtValue& value) {
|
||||
if (exists) {
|
||||
feeds_[index] = value;
|
||||
} else {
|
||||
auto add_or_replace = [&](const OrtValue& value) {
|
||||
if (it.second) {
|
||||
feed_names_.push_back(name);
|
||||
feeds_.push_back(value);
|
||||
} else {
|
||||
feeds_[it.first->second] = value;
|
||||
}
|
||||
};
|
||||
|
||||
if (ml_value.IsTensor() || ml_value.IsSparseTensor()) {
|
||||
OrtValue new_mlvalue;
|
||||
// Do not replace new_mlvalue by feeds_[index] in the following line.
|
||||
// It may copy the data instead of copying the pointer.
|
||||
// When OrtValue is empty, the pointer is copied. When it is not
|
||||
// (if feeds_[index] is not empty, for example),
|
||||
// CopyOneInputAcrossDevices has a different behaviour.
|
||||
ORT_RETURN_IF_ERROR(utils::CopyOneInputAcrossDevices(session_state_, name, ml_value, new_mlvalue));
|
||||
add_or_replace(rc.first, rc.second, new_mlvalue);
|
||||
add_or_replace(new_mlvalue);
|
||||
} else {
|
||||
add_or_replace(rc.first, rc.second, ml_value);
|
||||
add_or_replace(ml_value);
|
||||
}
|
||||
|
||||
ORT_ENFORCE(mapped_feed_names_.size() == feed_names_.size(), "Size mismatch:", mapped_feed_names_.size(), "!=", feed_names_.size(), " index=", it.first->second, " it.second=", it.second);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IOBinding::ClearInputs() {
|
||||
mapped_feed_names_.clear();
|
||||
feed_names_.clear();
|
||||
feeds_.clear();
|
||||
}
|
||||
|
|
@ -88,20 +88,23 @@ common::Status IOBinding::BindOutput(const std::string& name, OrtDevice device)
|
|||
}
|
||||
|
||||
common::Status IOBinding::BindOutputImpl(const std::string& name, const OrtValue& ml_value, OrtDevice device) {
|
||||
auto rc = Contains(output_names_, name);
|
||||
if (rc.first) {
|
||||
outputs_[rc.second] = ml_value;
|
||||
outputs_device_info_[rc.second] = device;
|
||||
} else {
|
||||
auto it = mapped_output_names_.emplace(name, output_names_.size());
|
||||
size_t index = it.first->second;
|
||||
if (it.second) {
|
||||
output_names_.push_back(name);
|
||||
outputs_.push_back(ml_value);
|
||||
outputs_device_info_.push_back(device);
|
||||
} else {
|
||||
outputs_[index] = ml_value;
|
||||
outputs_device_info_[index] = device;
|
||||
}
|
||||
ORT_ENFORCE(mapped_output_names_.size() == output_names_.size(), "Size mismatch", mapped_output_names_.size(), "!=", output_names_.size());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IOBinding::ClearOutputs() {
|
||||
mapped_output_names_.clear();
|
||||
output_names_.clear();
|
||||
outputs_.clear();
|
||||
outputs_device_info_.clear();
|
||||
|
|
|
|||
|
|
@ -100,8 +100,10 @@ class IOBinding {
|
|||
|
||||
const SessionState& session_state_;
|
||||
std::vector<std::string> feed_names_;
|
||||
std::unordered_map<std::string, size_t> mapped_feed_names_;
|
||||
std::vector<OrtValue> feeds_;
|
||||
std::vector<std::string> output_names_;
|
||||
std::unordered_map<std::string, size_t> mapped_output_names_;
|
||||
std::vector<OrtValue> outputs_;
|
||||
std::vector<OrtDevice> outputs_device_info_;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue