diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 09f2f1ad55..59a025a617 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -22,11 +22,15 @@ IExecutionFrame::IExecutionFrame(const std::vector& feed_mlvalue_idxs, cons const std::unordered_map& initializers, const std::vector& fetch_mlvalue_idxs, const std::vector& fetches, const OrtValueNameIdxMap& ort_value_idx_map, const NodeIndexInfo& node_index_info) - : node_index_info_{node_index_info}, fetch_mlvalue_idxs_{fetch_mlvalue_idxs} { + : node_index_info_{node_index_info}, + all_values_size_{static_cast(ort_value_idx_map.MaxIdx()) + 1}, + fetch_mlvalue_idxs_{fetch_mlvalue_idxs} { ORT_ENFORCE(feeds.size() == feed_mlvalue_idxs.size()); ORT_ENFORCE(fetches.empty() || fetches.size() == fetch_mlvalue_idxs_.size()); + ORT_ENFORCE(node_index_info_.GetMaxMLValueIdx() == ort_value_idx_map.MaxIdx(), + "node_index_info and ort_value_idx_map are out of sync and cannot be used"); - Init(feed_mlvalue_idxs, feeds, initializers, fetches, ort_value_idx_map); + Init(feed_mlvalue_idxs, feeds, initializers, fetches); } IExecutionFrame::~IExecutionFrame() = default; @@ -79,7 +83,7 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtAllocatorInfo& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { - if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_.size()) { + if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); } @@ -95,19 +99,16 @@ Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { } int IExecutionFrame::GetNodeIdxToMLValueIdx(int index) const { + // the validity of index is checked by GetMLValueIndex int ort_value_idx = node_index_info_.GetMLValueIndex(index); - ORT_ENFORCE(ort_value_idx == NodeIndexInfo::kInvalidEntry || - (ort_value_idx >= 0 && static_cast(ort_value_idx) < all_values_.size())); - return ort_value_idx; } void IExecutionFrame::Init(const std::vector& feed_mlvalue_idxs, const std::vector& feeds, const std::unordered_map& initializers, - const std::vector& fetches, - const OrtValueNameIdxMap& ort_value_idx_map) { + const std::vector& fetches) { // 1. resize the all_value_ vector - all_values_.resize(ort_value_idx_map.MaxIdx() + 1); + all_values_.resize(all_values_size_); // 2. Handle non-empty output vector if (!fetches.empty()) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index c99979edb7..06d042de3b 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -74,10 +74,10 @@ class IExecutionFrame { void Init(const std::vector& feed_mlvalue_idxs, const std::vector& feeds, const std::unordered_map& initializers, - const std::vector& fetches, const OrtValueNameIdxMap& ort_value_idx_map); + const std::vector& fetches); const OrtValue& GetMLValue(int ort_value_index) const { - ORT_ENFORCE(ort_value_index >= 0 && static_cast(ort_value_index) < all_values_.size()); + ORT_ENFORCE(ort_value_index >= 0 && static_cast(ort_value_index) < all_values_size_); return all_values_[ort_value_index]; } @@ -91,6 +91,9 @@ class IExecutionFrame { // Input and Output values are passed in by executors std::vector all_values_; + // perf optimization to avoid calling all_values_.size() repeatedly as the size is fixed once constructed + const size_t all_values_size_; + const std::vector fetch_mlvalue_idxs_; }; diff --git a/onnxruntime/core/framework/node_index_info.cc b/onnxruntime/core/framework/node_index_info.cc index 7931825e7f..d77a72cabc 100644 --- a/onnxruntime/core/framework/node_index_info.cc +++ b/onnxruntime/core/framework/node_index_info.cc @@ -69,6 +69,10 @@ void NodeIndexInfo::Init(const TValidNodes& nodes, NodeIndex max_node_index, // init all to kInvalidEntry node_offsets_.resize(GetNodeOffsetsIndex(max_node_index), kInvalidEntry); node_values_.resize(total_def_count, kInvalidEntry); + + node_offsets_size_ = node_offsets_.size(); + node_values_size_ = node_values_.size(); + int cur_idx = 0; for (auto& node : nodes) { diff --git a/onnxruntime/core/framework/node_index_info.h b/onnxruntime/core/framework/node_index_info.h index afd74a1874..19b4a202f5 100644 --- a/onnxruntime/core/framework/node_index_info.h +++ b/onnxruntime/core/framework/node_index_info.h @@ -31,14 +31,14 @@ class NodeIndexInfo final { // Returns kInvalidEntry if the Node with the given node_index did not exist when the NodeIndexInfo was created. int GetNodeOffset(NodeIndex node_index) const { auto node_offsets_index = GetNodeOffsetsIndex(node_index); - ORT_ENFORCE(node_offsets_index < node_offsets_.size()); + ORT_ENFORCE(node_offsets_index < node_offsets_size_); return node_offsets_[node_offsets_index]; } // Get the ort_value index value. // Returns kInvalidEntry for optional inputs/outputs that do not exist in this graph. int GetMLValueIndex(int offset) const { - ORT_ENFORCE(offset >= 0 && static_cast(offset) < node_values_.size()); + ORT_ENFORCE(offset >= 0 && static_cast(offset) < node_values_size_); return node_values_[offset]; } @@ -63,5 +63,9 @@ class NodeIndexInfo final { std::vector node_offsets_; const int max_mlvalue_idx_; + + // perf optimization to avoid calls to size() on node_values_ and node_offsets_ as they don't change + size_t node_values_size_; + size_t node_offsets_size_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.h b/onnxruntime/core/providers/cpu/math/element_wise_ops.h index 035ed3ee69..e9d28c8314 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.h +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.h @@ -320,6 +320,11 @@ struct BroadcastIterator { return index; } + void Reserve(int64_t max_dims) { + deltas_.reserve(max_dims); + counts_.reserve(max_dims); + } + void Init(int64_t axis, int64_t largest) { ORT_ENFORCE(axis == 1 || axis == largest, "Attempting to broadcast an axis by a dimension other than 1. ", axis, " by ", largest); @@ -368,6 +373,8 @@ struct Broadcaster { size_t dimension_count_max = std::max(shape1.size(), shape2.size()); size_t dimension_count_min = std::min(shape1.size(), shape2.size()); output_shape_.resize(dimension_count_max); + iterator1_.Reserve(dimension_count_max); + iterator2_.Reserve(dimension_count_max); auto iter1 = shape1.end(); auto iter2 = shape2.end(); @@ -395,22 +402,22 @@ struct Broadcaster { *--output_shape = axis; } index++; // Manually increment since we processed one axis - } + } else { + for (; index < dimension_count_min; index++) { + auto axis1 = *--iter1; + auto axis2 = *--iter2; - for (; index < dimension_count_min; index++) { - auto axis1 = *--iter1; - auto axis2 = *--iter2; + auto largest = std::max(axis1, axis2); + *--output_shape = largest; - auto largest = std::max(axis1, axis2); - *--output_shape = largest; + if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case + continue; - if (largest == 1 && index + 1 < dimension_count_min) // Nothing to do in this case - continue; - - iterator1_.Init(axis1, largest); - iterator2_.Init(axis2, largest); - index++; // Manually increment since we processed one axis - break; + iterator1_.Init(axis1, largest); + iterator2_.Init(axis2, largest); + index++; // Manually increment since we processed one axis + break; + } } for (; index < dimension_count_min; index++) {