// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include #include #include namespace onnxruntime { class Node; /** Class to filter out null entries from either a vector of unique_ptr or a vector of [const] Node* and provide an iterator interface that returns [const] Node& for the valid entries. */ template class ValidNodes { public: template class NodeIterator; // optional filtering function to return a subset of nodes using NodeFilterFunc = std::function; /** Construct a ValidNodes instance to provide iteration over all valid nodes in the TNodesCollection @param[in] nodes Nodes to iterate, skipping invalid entries. */ explicit ValidNodes(TNodesContainer& nodes) noexcept : nodes_(&nodes) {} explicit ValidNodes(TNodesContainer& nodes, NodeFilterFunc&& filter_node_fn) noexcept : nodes_(&nodes), filter_node_fn_{std::move(filter_node_fn)} {} using ConstNodeIterator = NodeIterator; using MutableNodeIterator = NodeIterator; using ConstReverseNodeIterator = NodeIterator; ConstNodeIterator cbegin() const noexcept { return {nodes_->cbegin(), nodes_->cend(), filter_node_fn_}; } ConstNodeIterator cend() const noexcept { return {nodes_->cend(), nodes_->cend(), filter_node_fn_}; } ConstNodeIterator begin() const noexcept { return cbegin(); } ConstNodeIterator end() const noexcept { return cend(); } ConstReverseNodeIterator rbegin() const noexcept { return {nodes_->crbegin(), nodes_->crend(), filter_node_fn_}; } ConstReverseNodeIterator rend() const noexcept { return {nodes_->crend(), nodes_->crend(), filter_node_fn_}; } // we only allow mutable access if the container is non-const. // we need to templatize the functions for enable_if to work at this level, but mandate T2 being TNodesContainer template typename std::enable_if::value, MutableNodeIterator>::type begin() noexcept { static_assert(std::is_same::value, "Explicit specialization is not allowed"); return MutableNodeIterator(nodes_->begin(), nodes_->end(), filter_node_fn_); } template typename std::enable_if::value, MutableNodeIterator>::type end() noexcept { static_assert(std::is_same::value, "Explicit specialization is not allowed"); return MutableNodeIterator(nodes_->end(), nodes_->end(), filter_node_fn_); } bool empty() const noexcept { return nodes_->empty(); } /** @class NodeIterator Iterator to provide const and non-const access to valid Node instances in a Graph. @remarks Skips invalid nodes. */ template class NodeIterator { // get the type being returned by the iterator. can't use TIterator::value_type as that is always non-const using IterType = typename std::remove_reference::reference>::type; // and determine what we will return based on its constness using T = typename std::conditional::value, const Node, // return const Node if this is a const iterator Node>::type; // else return Node public: using iterator_category = std::input_iterator_tag; using value_type = T; using difference_type = typename TIterator::difference_type; using pointer = T*; using reference = T&; using const_reference = const T&; /** Construct a NodeInterator and move to the first valid node. */ NodeIterator(const TIterator current, const TIterator end, const NodeFilterFunc& filter_fn) noexcept : current_{current}, end_{end}, apply_filter_{filter_fn != nullptr}, filter_func_{&filter_fn} { // skip to next valid node, stopping at end if none are found while (current_ < end && (*current_ == nullptr || (apply_filter_ && (*filter_func_)((*current_)->Index()) == true))) { ++current_; } } bool operator==(const NodeIterator& other) const noexcept { return (current_ == other.current_); } bool operator!=(const NodeIterator& other) const noexcept { return (current_ != other.current_); } void operator++() { if (current_ < end_) { while (++current_ != end_) { if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false)) break; } } } NodeIterator operator++(int) { NodeIterator tmp{*this}; ++(*this); return tmp; } /** Return the current Node&. This will be const if the iterator was returned from a const GraphNodes instance. */ reference operator*() { // if iterator is valid we always have a non-nullptr node // if this is a nullptr we're at end_ and this shouldn't be being called return **current_; } pointer operator->() { return current_->get(); } private: TIterator current_; TIterator end_; bool apply_filter_; // store whether filter_func_ is not nullptr and contains a callable const NodeFilterFunc* filter_func_; // store as pointer so iterator is copyable }; private: gsl::not_null nodes_; // always set by ctor // no filtering if not set. this instance owns the filter func if set. NodeFilterFunc filter_node_fn_; }; /** Class that provides iteration over all valid nodes in the Graph. */ class GraphNodes : public ValidNodes>> { public: GraphNodes(std::vector>& nodes) : ValidNodes(nodes) { } }; // Variant that only ever allows const access to nodes and optionally allows filtering of the nodes. class ConstGraphNodes : public ValidNodes>> { public: ConstGraphNodes(const std::vector>& nodes) : ValidNodes(nodes) { } ConstGraphNodes(const std::vector>& nodes, GraphNodes::NodeFilterFunc&& filter_func) : ValidNodes(nodes, std::move(filter_func)) { } }; } // namespace onnxruntime