From 2e563c417cfc6fa7db6b465513b16b757e95ef8b Mon Sep 17 00:00:00 2001 From: Duc Ngo Date: Mon, 20 Aug 2018 10:52:11 -0700 Subject: [PATCH] Nomnigraph - rename some APIs that invole Subtree to Subgraph (#10551) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10551 Renaming from "subtree" -> "subgraph" to improve clarity of subgraph matcher APIs since it now supports DAG This is pure renaming, no functionalities change. Reviewed By: bwasti Differential Revision: D9348311 fbshipit-source-id: 4b9267845950f3029dfe385ce3257d3abb8bdad4 --- .../nomnigraph/Representations/NeuralNet.cc | 5 +- .../nomnigraph/Representations/NeuralNet.h | 4 +- .../Transformations/SubgraphMatcher.h | 80 +++++++++---------- .../core/nomnigraph/tests/neural_net_test.cc | 16 ++-- .../nomnigraph/tests/subgraph_matcher_test.cc | 80 +++++++++---------- 5 files changed, 93 insertions(+), 92 deletions(-) diff --git a/caffe2/core/nomnigraph/Representations/NeuralNet.cc b/caffe2/core/nomnigraph/Representations/NeuralNet.cc index a60ddb127d5..c31de031f85 100644 --- a/caffe2/core/nomnigraph/Representations/NeuralNet.cc +++ b/caffe2/core/nomnigraph/Representations/NeuralNet.cc @@ -199,12 +199,13 @@ NNNodeMatchCriteria matchAnyNode() { [](NNGraph::NodeRef /* unused */) { return true; }, "matchAnyNode"); } -NNMatchGraph::NodeRef operatorTree( +NNMatchGraph::NodeRef operatorSubgraph( NNMatchGraph& g, const NNNodeMatchCriteria& root, const std::vector& childrenCriteria, int count) { - return tree(g, matchAnyNode(), {tree(g, root, childrenCriteria)}, count); + return subgraph( + g, matchAnyNode(), {subgraph(g, root, childrenCriteria)}, count); } } // namespace nn diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h index ac4e1fa6132..98e1bcba123 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Representations/NeuralNet.h @@ -487,9 +487,9 @@ using NNSubgraphMatcher = nom::matcher::SubgraphMatcher; // This helper method makes it easy to create matching criteria in NNGraph. -// For example, operatorTree(opMatch, ...) will refer to a tree like this: +// For example, operatorSubgraph(opMatch, ...) will refer to a tree like this: // ... -> opMatch -> opMatch_Output -NNMatchGraph::NodeRef operatorTree( +NNMatchGraph::NodeRef operatorSubgraph( NNMatchGraph& g, const NNNodeMatchCriteria& root, const std::vector& childrenCriteria = {}, diff --git a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h index 5c80bea2d07..9e0f44c896a 100644 --- a/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h +++ b/caffe2/core/nomnigraph/include/nomnigraph/Transformations/SubgraphMatcher.h @@ -57,7 +57,7 @@ template using MatchNodeRef = typename MatchGraph::NodeRef; template -MatchNodeRef tree( +MatchNodeRef subgraph( MatchGraph& graph, const NodeMatchCriteria& root, const std::vector>& children, @@ -97,19 +97,20 @@ std::string debugString(MatchNodeRef rootCriteriaRef) { } template -class SubtreeMatchResult { +class SubgraphMatchResult { public: - static SubtreeMatchResult notMatched( + static SubgraphMatchResult notMatched( const std::string& debugMessage) { - return SubtreeMatchResult(false, debugMessage); + return SubgraphMatchResult(false, debugMessage); } - static SubtreeMatchResult notMatched() { - return SubtreeMatchResult(false, "Debug message is not enabled"); + static SubgraphMatchResult notMatched() { + return SubgraphMatchResult( + false, "Debug message is not enabled"); } - static SubtreeMatchResult matched() { - return SubtreeMatchResult(true, "Matched"); + static SubgraphMatchResult matched() { + return SubgraphMatchResult(true, "Matched"); } bool isMatch() const { @@ -121,7 +122,7 @@ class SubtreeMatchResult { } private: - SubtreeMatchResult(bool isMatch, const std::string& debugMessage) + SubgraphMatchResult(bool isMatch, const std::string& debugMessage) : isMatch_(isMatch), debugMessage_(debugMessage) {} const bool isMatch_; @@ -142,12 +143,12 @@ struct SubgraphMatcher { return NodeMatcherClass::isMatch(node, criteria); } - // Check if there can be a sub-tree that matches the given criteria that + // Check if there can be a subgraph that matches the given criteria that // is rooted at the given rootNode. // The flag invertGraphTraversal specify if we should follow out edges or // in edges. The default is true which is useful for a functional // intepretation of a dataflow graph. - static SubtreeMatchResult isSubtreeMatch( + static SubgraphMatchResult isSubgraphMatch( typename GraphType::NodeRef root, const MatchNodeRef& rootCriteriaRef, bool invertGraphTraversal = true, @@ -156,25 +157,24 @@ struct SubgraphMatcher { MatchNodeRef, typename GraphType::NodeRef> matchedNodes; - return isSubtreeMatchInternal( + return isSubgraphMatchInternal( matchedNodes, root, rootCriteriaRef, invertGraphTraversal, debug); } - // Utility to transform a graph by looking for subtrees that match + // Utility to transform a graph by looking for subgraphs that match // a given pattern and then allow callers to mutate the graph based on - // subtrees that are found. + // subgraphs that are found. // The current implementation doesn't handle any graph transformation // itself. Callers should be responsible for all intended mutation, including - // deleting nodes in the subtrees found by this algorithm. + // deleting nodes in the subgraphs found by this algorithm. // Note: if the replaceFunction lambda returns false, the entire procedure // is aborted. This maybe useful in certain cases when we want to terminate - // the subtree search early. - // invertGraphTraversal flag: see documentation in isSubtreeMatch - static void replaceSubtree( + // the subgraph search early. + // invertGraphTraversal flag: see documentation in isSubgraphMatch + static void replaceSubgraph( GraphType& graph, const MatchNodeRef& criteria, - const std::function< - bool(GraphType& g, typename GraphType::NodeRef subtreeRoot)>& + const std::function& replaceFunction, bool invertGraphTraversal = true) { for (auto nodeRef : graph.getMutableNodes()) { @@ -182,7 +182,7 @@ struct SubgraphMatcher { if (!graph.hasNode(nodeRef)) { continue; } - if (isSubtreeMatch(nodeRef, criteria, invertGraphTraversal).isMatch()) { + if (isSubgraphMatch(nodeRef, criteria, invertGraphTraversal).isMatch()) { if (!replaceFunction(graph, nodeRef)) { // If replaceFunction returns false, it means that we should abort // the entire procedure. @@ -193,7 +193,7 @@ struct SubgraphMatcher { } private: - static SubtreeMatchResult isSubtreeMatchInternal( + static SubgraphMatchResult isSubgraphMatchInternal( std::unordered_map< MatchNodeRef, typename GraphType::NodeRef>& matchedNodes, @@ -211,15 +211,15 @@ struct SubgraphMatcher { // and verify if it is the same. auto matchedNode = matchedNodeEntry->second; if (matchedNode == root) { - return SubtreeMatchResult::matched(); + return SubgraphMatchResult::matched(); } else if (debug) { std::ostringstream debugMessage; - debugMessage << "Subtree root at " << root << " is not the same as " + debugMessage << "Subgraph root at " << root << " is not the same as " << matchedNode << " which previously matched criteria " << debugString(rootCriteriaRef); - return SubtreeMatchResult::notMatched(debugMessage.str()); + return SubgraphMatchResult::notMatched(debugMessage.str()); } else { - return SubtreeMatchResult::notMatched(); + return SubgraphMatchResult::notMatched(); } } } @@ -227,19 +227,19 @@ struct SubgraphMatcher { if (!isNodeMatch(root, rootCriteriaNode.getCriteria())) { if (debug) { std::ostringstream debugMessage; - debugMessage << "Subtree root at " << root + debugMessage << "Subgraph root at " << root << " does not match criteria " << debugString(rootCriteriaRef); - return SubtreeMatchResult::notMatched(debugMessage.str()); + return SubgraphMatchResult::notMatched(debugMessage.str()); } else { - return SubtreeMatchResult::notMatched(); + return SubgraphMatchResult::notMatched(); } } if (rootCriteriaNode.isNonTerminal()) { // This is sufficient to be a match if this criteria specifies a non // terminal node. matchedNodes[rootCriteriaRef] = root; - return SubtreeMatchResult::matched(); + return SubgraphMatchResult::matched(); } auto& edges = invertGraphTraversal ? root->getInEdges() : root->getOutEdges(); @@ -249,7 +249,7 @@ struct SubgraphMatcher { int numChildrenCriteria = outEdges.size(); // The current algorithm implies that the ordering of the children is - // important. The children nodes will be matched with the children subtree + // important. The children nodes will be matched with the children subgraph // criteria in the given order. int currentEdgeIdx = 0; @@ -273,7 +273,7 @@ struct SubgraphMatcher { auto edge = edges[currentEdgeIdx]; auto child = invertGraphTraversal ? edge->tail() : edge->head(); - if (!isSubtreeMatchInternal( + if (!isSubgraphMatchInternal( matchedNodes, child, childrenCriteriaRef, invertGraphTraversal) .isMatch()) { if (!isStarCount) { @@ -287,10 +287,10 @@ struct SubgraphMatcher { childrenCriteriaRef) << ". We expected " << expectedCount << " matches but only found " << countMatch << "."; - return SubtreeMatchResult::notMatched( + return SubgraphMatchResult::notMatched( debugMessage.str()); } else { - return SubtreeMatchResult::notMatched(); + return SubgraphMatchResult::notMatched(); } } else { // Otherwise, we should move on to the next children criteria. @@ -310,9 +310,9 @@ struct SubgraphMatcher { << " matches for child criteria " << debugString(childrenCriteriaRef) << " but only found " << countMatch; - return SubtreeMatchResult::notMatched(debugMessage.str()); + return SubgraphMatchResult::notMatched(debugMessage.str()); } else { - return SubtreeMatchResult::notMatched(); + return SubgraphMatchResult::notMatched(); } } } @@ -321,17 +321,17 @@ struct SubgraphMatcher { // Fails because there are unmatched edges. if (debug) { std::ostringstream debugMessage; - debugMessage << "Unmatched children for subtree root at " << root + debugMessage << "Unmatched children for subgraph root at " << root << ". There are " << numEdges << " children, but only found " << currentEdgeIdx << " matches for the children criteria."; - return SubtreeMatchResult::notMatched(debugMessage.str()); + return SubgraphMatchResult::notMatched(debugMessage.str()); } else { - return SubtreeMatchResult::notMatched(); + return SubgraphMatchResult::notMatched(); } } matchedNodes[rootCriteriaRef] = root; - return SubtreeMatchResult::matched(); + return SubgraphMatchResult::matched(); } }; diff --git a/caffe2/core/nomnigraph/tests/neural_net_test.cc b/caffe2/core/nomnigraph/tests/neural_net_test.cc index bdafce3b364..34dd9840309 100644 --- a/caffe2/core/nomnigraph/tests/neural_net_test.cc +++ b/caffe2/core/nomnigraph/tests/neural_net_test.cc @@ -44,23 +44,23 @@ TEST(NeuralNetGraph, ReplaceGraph) { auto mg = NNMatchGraph(); // clang-format off - auto pattern = tree(mg, + auto pattern = subgraph(mg, matchNodeType(), { - operatorTree(mg, + operatorSubgraph(mg, matchNodeType(), { - tree(mg, matchNodeType(), {}, 2, true) + subgraph(mg, matchNodeType(), {}, 2, true) }), }); // clang-format on - EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(sum, pattern).isMatch()); + EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(sum, pattern).isMatch()); EXPECT_FALSE( - NNSubgraphMatcher::isSubtreeMatch(reluOutput, pattern).isMatch()); - EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(input1, pattern).isMatch()); + NNSubgraphMatcher::isSubgraphMatch(reluOutput, pattern).isMatch()); + EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(input1, pattern).isMatch()); - EXPECT_TRUE(NNSubgraphMatcher::isSubtreeMatch(relu, pattern).isMatch()); + EXPECT_TRUE(NNSubgraphMatcher::isSubgraphMatch(relu, pattern).isMatch()); - NNSubgraphMatcher::replaceSubtree( + NNSubgraphMatcher::replaceSubgraph( graph, pattern, [](NNGraph& g, NNGraph::NodeRef relu) { auto sumOutput = getInputs(relu)[0]; auto sum = getProducer(sumOutput); diff --git a/caffe2/core/nomnigraph/tests/subgraph_matcher_test.cc b/caffe2/core/nomnigraph/tests/subgraph_matcher_test.cc index 7a5ed1af548..ced26d69beb 100644 --- a/caffe2/core/nomnigraph/tests/subgraph_matcher_test.cc +++ b/caffe2/core/nomnigraph/tests/subgraph_matcher_test.cc @@ -41,11 +41,11 @@ TestMatchGraph::NodeRef Tree( const Criteria& root, const std::vector& children = {}, int count = 1) { - return tree(graph, root, children, count, false); + return subgraph(graph, root, children, count, false); } TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) { - return tree(graph, root, {}, count, true); + return subgraph(graph, root, {}, count, true); } Criteria any() { @@ -202,11 +202,11 @@ TestGraph::NodeRef getInNode(TestGraph::NodeRef node, int index) { return node->getInEdges()[index]->tail(); } -bool isSubtreeMatch( +bool isSubgraphMatch( TestGraph::NodeRef nodeRef, const TestMatchGraph::NodeRef& criteria, bool invertGraphTraversal = true) { - return TestMatcher::isSubtreeMatch(nodeRef, criteria, invertGraphTraversal) + return TestMatcher::isSubgraphMatch(nodeRef, criteria, invertGraphTraversal) .isMatch(); } } // namespace matcher @@ -254,32 +254,32 @@ TEST(SubgraphMatcher, IsSubtreeMatch) { reset(); auto subtree = Tree(any(), {Tree(any()), Tree(any())}); - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n4, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n4, subtree, false)); - EXPECT_TRUE(isSubtreeMatch(n2, subtree, false)); - EXPECT_TRUE(isSubtreeMatch(n5, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n2, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n5, subtree, false)); reset(); subtree = Tree(Criteria("5"), {Tree(any()), Tree(any())}); - EXPECT_FALSE(isSubtreeMatch(n2, subtree, false)); - EXPECT_TRUE(isSubtreeMatch(n5, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n2, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n5, subtree, false)); reset(); subtree = Tree(any(), {Tree(any()), Tree(Criteria("4"))}); - EXPECT_TRUE(isSubtreeMatch(n2, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n5, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n2, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n5, subtree, false)); reset(); // Accepts non terminal node subtree = Tree(any(), {NonTerminal(any()), NonTerminal(any())}); - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); - EXPECT_TRUE(isSubtreeMatch(n2, subtree, false)); - EXPECT_TRUE(isSubtreeMatch(n5, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n3, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n4, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n6, subtree, false)); - EXPECT_FALSE(isSubtreeMatch(n7, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n2, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n5, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n3, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n4, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n6, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n7, subtree, false)); } // Test subtree matching in which * (repeated) matching of children is allowed. @@ -304,11 +304,11 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { reset(); auto subtree = Tree(any(), {Tree(Criteria("2"))}); - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), {Tree(Criteria("2"), {}, TestMatchNode::kStarCount)}); - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); // clang-format off @@ -318,7 +318,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("4"), {}, 2), Tree(Criteria("5"), {}, 3) }); - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { @@ -328,7 +328,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("5"), {}, 4) }); // Failes because exepected 4 matches of n5 but found 3. - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { @@ -337,7 +337,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("4"), {}, 2), Tree(Criteria("5"), {}, TestMatchNode::kStarCount) }); - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { @@ -346,7 +346,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("4"), {}, 2), Tree(Criteria("5"), {}, TestMatchNode::kStarCount) }); - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { @@ -354,7 +354,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { Tree(Criteria("3"), {}, TestMatchNode::kStarCount), }); // Fails because there are unmatched edges. - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); reset(); subtree = Tree(any(), { @@ -365,7 +365,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) { }); // Fails because the count is wrong; we have 2 edges to node N4 while // the pattern expects only 1. - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subtree, false)); // clang-format on } @@ -376,7 +376,7 @@ TEST(SubgraphMatcher, DagMatching) { auto n4match = Tree(Criteria("4"), { Tree(Criteria("5")) }); - auto subtree = Tree(Criteria("1"), { + auto subgraph = Tree(Criteria("1"), { Tree(Criteria("2"), { n4match }), @@ -409,7 +409,7 @@ TEST(SubgraphMatcher, DagMatching) { N5 */ - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false)); } { @@ -438,7 +438,7 @@ TEST(SubgraphMatcher, DagMatching) { */ // This should fail because n4A and n4B are not the same node. - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false)); } } @@ -447,7 +447,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) { // clang-format off auto n2match = Tree(Criteria("2")); - auto subtree = Tree(Criteria("1"), { + auto subgraph = Tree(Criteria("1"), { n2match, n2match }); @@ -461,7 +461,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) { graph.createEdge(n1, n2); graph.createEdge(n1, n2); - EXPECT_TRUE(isSubtreeMatch(n1, subtree, false)); + EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false)); } { @@ -473,7 +473,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) { graph.createEdge(n1, n2A); graph.createEdge(n1, n2B); - EXPECT_FALSE(isSubtreeMatch(n1, subtree, false)); + EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false)); } } @@ -533,7 +533,7 @@ TEST(SubgraphMatcher, DagMatchingRandomLargeGraph) { int countMatch = 0; for (auto node : graph.getMutableNodes()) { - if (isSubtreeMatch(node, subtree, false)) { + if (isSubgraphMatch(node, subtree, false)) { countMatch++; } } @@ -545,12 +545,12 @@ TEST(SubgraphMatcher, IsSubtreeMatchRealistic) { auto graph = DataFlowTestGraph(); auto subtree = DataFlowTestGraphCriteria(); - EXPECT_FALSE(isSubtreeMatch(graph.opF, subtree)); - EXPECT_FALSE(isSubtreeMatch(graph.opC, subtree)); - EXPECT_FALSE(isSubtreeMatch(graph.opB, subtree)); - EXPECT_FALSE(isSubtreeMatch(graph.dataOut, subtree)); + EXPECT_FALSE(isSubgraphMatch(graph.opF, subtree)); + EXPECT_FALSE(isSubgraphMatch(graph.opC, subtree)); + EXPECT_FALSE(isSubgraphMatch(graph.opB, subtree)); + EXPECT_FALSE(isSubgraphMatch(graph.dataOut, subtree)); - EXPECT_TRUE(isSubtreeMatch(graph.opG, subtree)); + EXPECT_TRUE(isSubgraphMatch(graph.opG, subtree)); } TEST(SubgraphMatcher, ReplaceSubtreeRealistic) { @@ -558,7 +558,7 @@ TEST(SubgraphMatcher, ReplaceSubtreeRealistic) { auto graph = DataFlowTestGraph(); auto subtree = DataFlowTestGraphCriteria(); - TestMatcher::replaceSubtree( + TestMatcher::replaceSubgraph( graph.graph, subtree, [](TestGraph& g, TestGraph::NodeRef opG) { auto opFused = g.createNode("opFused");