Improve iobinding, faster name search (#10005)

* Improve iobinding, faster name search
This commit is contained in:
Xavier Dupré 2022-01-14 12:18:18 +01:00 committed by GitHub
parent 3ea7fb0f9f
commit e38e51ea8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 36 deletions

View file

@ -268,9 +268,12 @@ const OrtMemoryInfo& FindMemoryInfoForValue(const SessionState& session_state,
static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session_state,
const std::string& input_name,
MLValueCopyInfo& copy_info) {
#ifdef ENABLE_TRAINING
std::vector<SessionState::NodeInfo> node_info_vec;
#ifdef ENABLE_TRAINING
if (session_state.GetInputNodeInfo(input_name, node_info_vec) == Status::OK()) {
#else
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
#endif
const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine
if (node_info.p_node == nullptr) {
@ -280,6 +283,7 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
copy_info.target_device = *node_info.device;
#ifdef ENABLE_TRAINING
} else {
// This input might be for an intermediate tensor for partial graph execution.
const auto* exec_plan = session_state.GetExecutionPlan();
@ -289,22 +293,9 @@ static common::Status CalculateStaticCopyInfoForFeed(const SessionState& session
const auto& device = exec_plan->GetLocation(index).device;
copy_info.target_device = device;
}
return Status::OK();
#else
std::vector<SessionState::NodeInfo> node_info_vec;
ORT_RETURN_IF_ERROR(session_state.GetInputNodeInfo(input_name, node_info_vec));
const auto& node_info = node_info_vec.front(); // all consumers of a feed have the same device so first entry is fine
if (node_info.p_node == nullptr) {
// ignore dummy entry for an input that we didn't find a use of in the graph.
return Status::OK();
}
copy_info.target_device = *node_info.device;
return Status::OK();
#endif
return Status::OK();
}
static common::Status CalculateStaticCopyInfoForFeeds(const SessionState& session_state,
@ -508,6 +499,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
}
MLValueCopyInfo copy_info;
// Sets copy_info.target_device.
ORT_RETURN_IF_ERROR(CalculateStaticCopyInfoForFeed(session_state, input_name, copy_info));
#if !defined(DISABLE_SPARSE_TENSORS)
copy_info.source_device = (orig_mlvalue.IsTensor())
@ -517,6 +509,7 @@ common::Status CopyOneInputAcrossDevices(const SessionState& session_state, cons
copy_info.source_device = orig_mlvalue.Get<Tensor>().Location().device;
#endif
// If copy_info.target_device was not set above, it is left equal to CPU.
return BatchOrCopyMLValue(session_state, copy_info, orig_mlvalue, new_mlvalue);
}

View file

@ -11,38 +11,38 @@ namespace onnxruntime {
// Builds an IOBinding bound to `session_state`; the binding only stores a
// reference to it.
IOBinding::IOBinding(const SessionState& session_state)
    : session_state_(session_state) {}
// Linear scan of `names` for `name`.
// Returns {true, index of the first match} when found and {false, 0} otherwise.
static std::pair<bool, size_t> Contains(const std::vector<std::string>& names, const std::string& name) {
  const size_t count = names.size();
  for (size_t position = 0; position < count; ++position) {
    if (names[position] == name) {
      return {true, position};
    }
  }
  return {false, 0};
}
common::Status IOBinding::BindInput(const std::string& name, const OrtValue& ml_value) {
auto rc = Contains(feed_names_, name);
auto it = mapped_feed_names_.emplace(name, feed_names_.size());
auto add_or_replace = [this, &name](const bool exists, size_t index, const OrtValue& value) {
if (exists) {
feeds_[index] = value;
} else {
auto add_or_replace = [&](const OrtValue& value) {
if (it.second) {
feed_names_.push_back(name);
feeds_.push_back(value);
} else {
feeds_[it.first->second] = value;
}
};
if (ml_value.IsTensor() || ml_value.IsSparseTensor()) {
OrtValue new_mlvalue;
// Do not pass feeds_[index] in place of new_mlvalue in the following line:
// doing so may copy the data instead of copying the pointer.
// When the destination OrtValue is empty, the pointer is copied. When it is
// not empty (as feeds_[index] may be, for example),
// CopyOneInputAcrossDevices behaves differently.
ORT_RETURN_IF_ERROR(utils::CopyOneInputAcrossDevices(session_state_, name, ml_value, new_mlvalue));
add_or_replace(rc.first, rc.second, new_mlvalue);
add_or_replace(new_mlvalue);
} else {
add_or_replace(rc.first, rc.second, ml_value);
add_or_replace(ml_value);
}
ORT_ENFORCE(mapped_feed_names_.size() == feed_names_.size(), "Size mismatch:", mapped_feed_names_.size(), "!=", feed_names_.size(), " index=", it.first->second, " it.second=", it.second);
return Status::OK();
}
// Forgets every bound input: empties the stored feed values, the list of feed
// names, and the name-to-index lookup map.
void IOBinding::ClearInputs() {
  feeds_.clear();
  feed_names_.clear();
  mapped_feed_names_.clear();
}
@ -88,20 +88,23 @@ common::Status IOBinding::BindOutput(const std::string& name, OrtDevice device)
}
common::Status IOBinding::BindOutputImpl(const std::string& name, const OrtValue& ml_value, OrtDevice device) {
auto rc = Contains(output_names_, name);
if (rc.first) {
outputs_[rc.second] = ml_value;
outputs_device_info_[rc.second] = device;
} else {
auto it = mapped_output_names_.emplace(name, output_names_.size());
size_t index = it.first->second;
if (it.second) {
output_names_.push_back(name);
outputs_.push_back(ml_value);
outputs_device_info_.push_back(device);
} else {
outputs_[index] = ml_value;
outputs_device_info_[index] = device;
}
ORT_ENFORCE(mapped_output_names_.size() == output_names_.size(), "Size mismatch", mapped_output_names_.size(), "!=", output_names_.size());
return Status::OK();
}
void IOBinding::ClearOutputs() {
mapped_output_names_.clear();
output_names_.clear();
outputs_.clear();
outputs_device_info_.clear();

View file

@ -100,8 +100,10 @@ class IOBinding {
const SessionState& session_state_;
std::vector<std::string> feed_names_;
std::unordered_map<std::string, size_t> mapped_feed_names_;
std::vector<OrtValue> feeds_;
std::vector<std::string> output_names_;
std::unordered_map<std::string, size_t> mapped_output_names_;
std::vector<OrtValue> outputs_;
std::vector<OrtDevice> outputs_device_info_;