mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Minor perf improvements. (#1580)
* Minor perf improvements.
- Cache the vector sizes in IExecutionFrame and NodeIndexInfo to avoid calls to size().
- 2 instructions instead of 10
- Remove an unnecessary check in IExecutionFrame
- add a check to the ctor so we guarantee it's unnecessary
- Reserve memory for the vectors in BroadcastIterator
- saves reallocs if more than one value is added
- but it is rare for multiple values to be added with the mlperf models, so the benefit is limited.
- slight tweak to the Broadcaster ctor code to make it more readable
This commit is contained in:
parent
a6a4c4c079
commit
8a559d75ae
5 changed files with 45 additions and 26 deletions
|
|
@ -22,11 +22,15 @@ IExecutionFrame::IExecutionFrame(const std::vector<int>& feed_mlvalue_idxs, cons
|
|||
const std::unordered_map<int, OrtValue>& initializers,
|
||||
const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
|
||||
const OrtValueNameIdxMap& ort_value_idx_map, const NodeIndexInfo& node_index_info)
|
||||
: node_index_info_{node_index_info}, fetch_mlvalue_idxs_{fetch_mlvalue_idxs} {
|
||||
: node_index_info_{node_index_info},
|
||||
all_values_size_{static_cast<size_t>(ort_value_idx_map.MaxIdx()) + 1},
|
||||
fetch_mlvalue_idxs_{fetch_mlvalue_idxs} {
|
||||
ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size());
|
||||
ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size());
|
||||
ORT_ENFORCE(node_index_info_.GetMaxMLValueIdx() == ort_value_idx_map.MaxIdx(),
|
||||
"node_index_info and ort_value_idx_map are out of sync and cannot be used");
|
||||
|
||||
Init(feed_mlvalue_idxs, feeds, initializers, fetches, ort_value_idx_map);
|
||||
Init(feed_mlvalue_idxs, feeds, initializers, fetches);
|
||||
}
|
||||
|
||||
IExecutionFrame::~IExecutionFrame() = default;
|
||||
|
|
@ -79,7 +83,7 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtAllocatorInfo& info) const {
|
|||
Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); }
|
||||
|
||||
Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
|
||||
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_.size()) {
|
||||
if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast<size_t>(ort_value_idx) >= all_values_size_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx);
|
||||
}
|
||||
|
||||
|
|
@ -95,19 +99,16 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) {
|
|||
}
|
||||
|
||||
int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const {
|
||||
// the validity of index is checked by GetMLValueIndex
|
||||
int ort_value_idx = node_index_info_.GetMLValueIndex(index);
|
||||
ORT_ENFORCE(ort_value_idx == NodeIndexInfo::kInvalidEntry ||
|
||||
(ort_value_idx >= 0 && static_cast<size_t>(ort_value_idx) < all_values_.size()));
|
||||
|
||||
return ort_value_idx;
|
||||
}
|
||||
|
||||
void IExecutionFrame::Init(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
|
||||
const std::unordered_map<int, OrtValue>& initializers,
|
||||
const std::vector<OrtValue>& fetches,
|
||||
const OrtValueNameIdxMap& ort_value_idx_map) {
|
||||
const std::vector<OrtValue>& fetches) {
|
||||
// 1. resize the all_values_ vector
|
||||
all_values_.resize(ort_value_idx_map.MaxIdx() + 1);
|
||||
all_values_.resize(all_values_size_);
|
||||
|
||||
// 2. Handle non-empty output vector
|
||||
if (!fetches.empty()) {
|
||||
|
|
|
|||
|
|
@ -74,10 +74,10 @@ class IExecutionFrame {
|
|||
|
||||
void Init(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds,
|
||||
const std::unordered_map<int, OrtValue>& initializers,
|
||||
const std::vector<OrtValue>& fetches, const OrtValueNameIdxMap& ort_value_idx_map);
|
||||
const std::vector<OrtValue>& fetches);
|
||||
|
||||
const OrtValue& GetMLValue(int ort_value_index) const {
|
||||
ORT_ENFORCE(ort_value_index >= 0 && static_cast<size_t>(ort_value_index) < all_values_.size());
|
||||
ORT_ENFORCE(ort_value_index >= 0 && static_cast<size_t>(ort_value_index) < all_values_size_);
|
||||
return all_values_[ort_value_index];
|
||||
}
|
||||
|
||||
|
|
@ -91,6 +91,9 @@ class IExecutionFrame {
|
|||
// Input and Output values are passed in by executors
|
||||
std::vector<OrtValue> all_values_;
|
||||
|
||||
// perf optimization to avoid calling all_values_.size() repeatedly as the size is fixed once constructed
|
||||
const size_t all_values_size_;
|
||||
|
||||
const std::vector<int> fetch_mlvalue_idxs_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -69,6 +69,10 @@ void NodeIndexInfo::Init(const TValidNodes& nodes, NodeIndex max_node_index,
|
|||
// init all to kInvalidEntry
|
||||
node_offsets_.resize(GetNodeOffsetsIndex(max_node_index), kInvalidEntry);
|
||||
node_values_.resize(total_def_count, kInvalidEntry);
|
||||
|
||||
node_offsets_size_ = node_offsets_.size();
|
||||
node_values_size_ = node_values_.size();
|
||||
|
||||
int cur_idx = 0;
|
||||
|
||||
for (auto& node : nodes) {
|
||||
|
|
|
|||
|
|
@ -31,14 +31,14 @@ class NodeIndexInfo final {
|
|||
// Returns kInvalidEntry if the Node with the given node_index did not exist when the NodeIndexInfo was created.
|
||||
int GetNodeOffset(NodeIndex node_index) const {
|
||||
auto node_offsets_index = GetNodeOffsetsIndex(node_index);
|
||||
ORT_ENFORCE(node_offsets_index < node_offsets_.size());
|
||||
ORT_ENFORCE(node_offsets_index < node_offsets_size_);
|
||||
return node_offsets_[node_offsets_index];
|
||||
}
|
||||
|
||||
// Get the ort_value index value.
|
||||
// Returns kInvalidEntry for optional inputs/outputs that do not exist in this graph.
|
||||
int GetMLValueIndex(int offset) const {
|
||||
ORT_ENFORCE(offset >= 0 && static_cast<size_t>(offset) < node_values_.size());
|
||||
ORT_ENFORCE(offset >= 0 && static_cast<size_t>(offset) < node_values_size_);
|
||||
return node_values_[offset];
|
||||
}
|
||||
|
||||
|
|
@ -63,5 +63,9 @@ class NodeIndexInfo final {
|
|||
std::vector<int> node_offsets_;
|
||||
|
||||
const int max_mlvalue_idx_;
|
||||
|
||||
// perf optimization to avoid calls to size() on node_values_ and node_offsets_ as they don't change
|
||||
size_t node_values_size_;
|
||||
size_t node_offsets_size_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -320,6 +320,11 @@ struct BroadcastIterator {
|
|||
return index;
|
||||
}
|
||||
|
||||
void Reserve(int64_t max_dims) {
|
||||
deltas_.reserve(max_dims);
|
||||
counts_.reserve(max_dims);
|
||||
}
|
||||
|
||||
void Init(int64_t axis, int64_t largest) {
|
||||
ORT_ENFORCE(axis == 1 || axis == largest, "Attempting to broadcast an axis by a dimension other than 1. ", axis, " by ", largest);
|
||||
|
||||
|
|
@ -368,6 +373,8 @@ struct Broadcaster {
|
|||
size_t dimension_count_max = std::max(shape1.size(), shape2.size());
|
||||
size_t dimension_count_min = std::min(shape1.size(), shape2.size());
|
||||
output_shape_.resize(dimension_count_max);
|
||||
iterator1_.Reserve(dimension_count_max);
|
||||
iterator2_.Reserve(dimension_count_max);
|
||||
|
||||
auto iter1 = shape1.end();
|
||||
auto iter2 = shape2.end();
|
||||
|
|
@ -395,22 +402,22 @@ struct Broadcaster {
|
|||
*--output_shape = axis;
|
||||
}
|
||||
index++; // Manually increment since we processed one axis
|
||||
}
|
||||
} else {
|
||||
for (; index < dimension_count_min; index++) {
|
||||
auto axis1 = *--iter1;
|
||||
auto axis2 = *--iter2;
|
||||
|
||||
for (; index < dimension_count_min; index++) {
|
||||
auto axis1 = *--iter1;
|
||||
auto axis2 = *--iter2;
|
||||
auto largest = std::max(axis1, axis2);
|
||||
*--output_shape = largest;
|
||||
|
||||
auto largest = std::max(axis1, axis2);
|
||||
*--output_shape = largest;
|
||||
if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case
|
||||
continue;
|
||||
|
||||
if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case
|
||||
continue;
|
||||
|
||||
iterator1_.Init(axis1, largest);
|
||||
iterator2_.Init(axis2, largest);
|
||||
index++; // Manually increment since we processed one axis
|
||||
break;
|
||||
iterator1_.Init(axis1, largest);
|
||||
iterator2_.Init(axis2, largest);
|
||||
index++; // Manually increment since we processed one axis
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (; index < dimension_count_min; index++) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue