Filter initializers for GraphViewer with IndexedSubGraph (#5884)

* fix filtered subgraph initializer issue

* minor fix

* Inlcude implicit input of nodes to see if they are initializers

* Add test case

* minor update

* Address PR comments

* Fix some code error
This commit is contained in:
Guoyu Wang 2020-11-20 18:36:58 -08:00 committed by GitHub
parent ba739a8000
commit cc6e8fb7cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 4 deletions

View file

@ -95,7 +95,7 @@ class GraphViewer {
*/
const ConstGraphNodes& Nodes() const noexcept;
/** Gets the number of valid nodes in the Graph.
/** Gets the number of valid nodes in the Graph.
@remarks Returns the number of nodes in filter_info_ if set.
*/
int NumberOfNodes() const noexcept;
@ -103,7 +103,7 @@ class GraphViewer {
/** Gets the maximum NodeIndex value used by Nodes in the Graph. */
int MaxNodeIndex() const noexcept;
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order.
/** Gets the NodeIndex values for the Graph nodes, sorted into topological order.
@remarks Filtered using filter_info_ if set.
*/
const std::vector<NodeIndex>& GetNodesInTopologicalOrder(ExecutionOrder order = ExecutionOrder::DEFAULT) const;
@ -138,7 +138,7 @@ class GraphViewer {
/**
returns true if 'name' is an initializer, and is constant and cannot be overridden at runtime.
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
@param check_outer_scope If true and the 'graph_' is a subgraph, check parent graph/s for 'name'
if the name is not found in 'graph_'.
*/
bool IsConstantInitializer(const std::string& name, bool check_outer_scope) const;
@ -188,5 +188,6 @@ class GraphViewer {
std::vector<const NodeArg*> filtered_node_inputs_;
std::vector<const NodeArg*> filtered_node_inputs_including_initializers_;
std::vector<const NodeArg*> filtered_node_outputs_;
InitializedTensorSet filtered_initializers_;
};
} // namespace onnxruntime

View file

@ -121,6 +121,26 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
std::copy_if(orig_order.cbegin(), orig_order.cend(), std::back_inserter(nodes_in_topological_order_),
[this](NodeIndex idx) { return filtered_node_indices_.count(idx) != 0; });
// Filter the initializers also
// Get the names of all the inputs and implicit inputs of all the nodes in this subgraph
for (const auto node_idx : filtered_node_indices_) {
const auto* node = GetNode(node_idx);
ORT_ENFORCE(node, "Mismatch between Graph and IndexedSubGraph. Node not found: ", node_idx);
const ONNX_NAMESPACE::TensorProto* tensor = nullptr;
for (const auto* node_input : node->InputDefs()) {
if (graph.GetInitializedTensor(node_input->Name(), tensor)) {
filtered_initializers_.insert({node_input->Name(), tensor});
}
}
// The implicit inputs for subgraphs (if any)
for (const auto* node_input : node->ImplicitInputDefs()) {
if (graph.GetInitializedTensor(node_input->Name(), tensor)) {
filtered_initializers_.insert({node_input->Name(), tensor});
}
}
}
#if !defined(ORT_MINIMAL_BUILD)
auto orig_priority_order = std::move(nodes_in_topological_order_with_priority_);
nodes_in_topological_order_with_priority_.reserve(filter_info->nodes.size());
@ -146,6 +166,10 @@ const std::string& GraphViewer::Description() const noexcept {
bool GraphViewer::GetInitializedTensor(const std::string& tensor_name,
const ONNX_NAMESPACE::TensorProto*& value) const {
// if we are using filtered subgraph, the initializer has to be part of the subgraph
if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend())
return false;
return graph_->GetInitializedTensor(tensor_name, value);
}
@ -220,7 +244,9 @@ const std::vector<NodeIndex>& GraphViewer::GetRootNodes() const {
}
const InitializedTensorSet& GraphViewer::GetAllInitializedTensors() const noexcept {
return graph_->GetAllInitializedTensors();
return (filter_info_ == nullptr)
? graph_->GetAllInitializedTensors()
: filtered_initializers_;
}
const NodeArg* GraphViewer::GetNodeArg(const std::string& name) const {

View file

@ -92,6 +92,14 @@ TEST(GraphViewer, FilteredGraph) {
EXPECT_EQ(viewer.GetOutputs().size(), final_metadef->outputs.size());
EXPECT_EQ(viewer.IsSubgraph(), false)
<< "GraphViewer is for a filtered set of nodes of a single graph and not a nested subgraph";
// Verify the viewer's initializers are filtered as well
const auto& viewer_initializers = viewer.GetAllInitializedTensors();
EXPECT_EQ(viewer_initializers.size(), initializers.size());
// We should have less initializers in the viewer than the underlying graph
EXPECT_LT(viewer_initializers.size(), graph.GetAllInitializedTensors().size());
// Pick a initializers which is not in the viewer, and check it is not part of the viewer's initializers
EXPECT_TRUE(viewer_initializers.count("Constant15770PastValue16469") == 0);
}
} // namespace test
} // namespace onnxruntime