Nomnigraph - rename some APIs that invole Subtree to Subgraph (#10551)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10551

Renaming from "subtree" -> "subgraph" to improve clarity of subgraph matcher APIs since it now supports DAG

This is pure renaming, no functionalities change.

Reviewed By: bwasti

Differential Revision: D9348311

fbshipit-source-id: 4b9267845950f3029dfe385ce3257d3abb8bdad4
This commit is contained in:
Duc Ngo 2018-08-20 10:52:11 -07:00 committed by Facebook Github Bot
parent aa9f328fa3
commit 2e563c417c
5 changed files with 93 additions and 92 deletions

View file

@ -199,12 +199,13 @@ NNNodeMatchCriteria matchAnyNode() {
[](NNGraph::NodeRef /* unused */) { return true; }, "matchAnyNode");
}
NNMatchGraph::NodeRef operatorTree(
NNMatchGraph::NodeRef operatorSubgraph(
NNMatchGraph& g,
const NNNodeMatchCriteria& root,
const std::vector<NNMatchGraph::NodeRef>& childrenCriteria,
int count) {
return tree(g, matchAnyNode(), {tree(g, root, childrenCriteria)}, count);
return subgraph(
g, matchAnyNode(), {subgraph(g, root, childrenCriteria)}, count);
}
} // namespace nn

View file

@ -487,9 +487,9 @@ using NNSubgraphMatcher =
nom::matcher::SubgraphMatcher<NNGraph, NNNodeMatchCriteria, NNNodeMatch>;
// This helper method makes it easy to create matching criteria in NNGraph.
// For example, operatorTree(opMatch, ...) will refer to a tree like this:
// For example, operatorSubgraph(opMatch, ...) will refer to a tree like this:
// ... -> opMatch -> opMatch_Output
NNMatchGraph::NodeRef operatorTree(
NNMatchGraph::NodeRef operatorSubgraph(
NNMatchGraph& g,
const NNNodeMatchCriteria& root,
const std::vector<NNMatchGraph::NodeRef>& childrenCriteria = {},

View file

@ -57,7 +57,7 @@ template <typename NodeMatchCriteria>
using MatchNodeRef = typename MatchGraph<NodeMatchCriteria>::NodeRef;
template <typename NodeMatchCriteria>
MatchNodeRef<NodeMatchCriteria> tree(
MatchNodeRef<NodeMatchCriteria> subgraph(
MatchGraph<NodeMatchCriteria>& graph,
const NodeMatchCriteria& root,
const std::vector<MatchNodeRef<NodeMatchCriteria>>& children,
@ -97,19 +97,20 @@ std::string debugString(MatchNodeRef<NodeMatchCriteria> rootCriteriaRef) {
}
template <typename GraphType>
class SubtreeMatchResult {
class SubgraphMatchResult {
public:
static SubtreeMatchResult<GraphType> notMatched(
static SubgraphMatchResult<GraphType> notMatched(
const std::string& debugMessage) {
return SubtreeMatchResult<GraphType>(false, debugMessage);
return SubgraphMatchResult<GraphType>(false, debugMessage);
}
static SubtreeMatchResult<GraphType> notMatched() {
return SubtreeMatchResult<GraphType>(false, "Debug message is not enabled");
static SubgraphMatchResult<GraphType> notMatched() {
return SubgraphMatchResult<GraphType>(
false, "Debug message is not enabled");
}
static SubtreeMatchResult<GraphType> matched() {
return SubtreeMatchResult<GraphType>(true, "Matched");
static SubgraphMatchResult<GraphType> matched() {
return SubgraphMatchResult<GraphType>(true, "Matched");
}
bool isMatch() const {
@ -121,7 +122,7 @@ class SubtreeMatchResult {
}
private:
SubtreeMatchResult(bool isMatch, const std::string& debugMessage)
SubgraphMatchResult(bool isMatch, const std::string& debugMessage)
: isMatch_(isMatch), debugMessage_(debugMessage) {}
const bool isMatch_;
@ -142,12 +143,12 @@ struct SubgraphMatcher {
return NodeMatcherClass::isMatch(node, criteria);
}
// Check if there can be a sub-tree that matches the given criteria that
// Check if there can be a subgraph that matches the given criteria that
// is rooted at the given rootNode.
// The flag invertGraphTraversal specify if we should follow out edges or
// in edges. The default is true which is useful for a functional
// intepretation of a dataflow graph.
static SubtreeMatchResult<GraphType> isSubtreeMatch(
static SubgraphMatchResult<GraphType> isSubgraphMatch(
typename GraphType::NodeRef root,
const MatchNodeRef<NodeMatchCriteria>& rootCriteriaRef,
bool invertGraphTraversal = true,
@ -156,25 +157,24 @@ struct SubgraphMatcher {
MatchNodeRef<NodeMatchCriteria>,
typename GraphType::NodeRef>
matchedNodes;
return isSubtreeMatchInternal(
return isSubgraphMatchInternal(
matchedNodes, root, rootCriteriaRef, invertGraphTraversal, debug);
}
// Utility to transform a graph by looking for subtrees that match
// Utility to transform a graph by looking for subgraphs that match
// a given pattern and then allow callers to mutate the graph based on
// subtrees that are found.
// subgraphs that are found.
// The current implementation doesn't handle any graph transformation
// itself. Callers should be responsible for all intended mutation, including
// deleting nodes in the subtrees found by this algorithm.
// deleting nodes in the subgraphs found by this algorithm.
// Note: if the replaceFunction lambda returns false, the entire procedure
// is aborted. This maybe useful in certain cases when we want to terminate
// the subtree search early.
// invertGraphTraversal flag: see documentation in isSubtreeMatch
static void replaceSubtree(
// the subgraph search early.
// invertGraphTraversal flag: see documentation in isSubgraphMatch
static void replaceSubgraph(
GraphType& graph,
const MatchNodeRef<NodeMatchCriteria>& criteria,
const std::function<
bool(GraphType& g, typename GraphType::NodeRef subtreeRoot)>&
const std::function<bool(GraphType&, typename GraphType::NodeRef)>&
replaceFunction,
bool invertGraphTraversal = true) {
for (auto nodeRef : graph.getMutableNodes()) {
@ -182,7 +182,7 @@ struct SubgraphMatcher {
if (!graph.hasNode(nodeRef)) {
continue;
}
if (isSubtreeMatch(nodeRef, criteria, invertGraphTraversal).isMatch()) {
if (isSubgraphMatch(nodeRef, criteria, invertGraphTraversal).isMatch()) {
if (!replaceFunction(graph, nodeRef)) {
// If replaceFunction returns false, it means that we should abort
// the entire procedure.
@ -193,7 +193,7 @@ struct SubgraphMatcher {
}
private:
static SubtreeMatchResult<GraphType> isSubtreeMatchInternal(
static SubgraphMatchResult<GraphType> isSubgraphMatchInternal(
std::unordered_map<
MatchNodeRef<NodeMatchCriteria>,
typename GraphType::NodeRef>& matchedNodes,
@ -211,15 +211,15 @@ struct SubgraphMatcher {
// and verify if it is the same.
auto matchedNode = matchedNodeEntry->second;
if (matchedNode == root) {
return SubtreeMatchResult<GraphType>::matched();
return SubgraphMatchResult<GraphType>::matched();
} else if (debug) {
std::ostringstream debugMessage;
debugMessage << "Subtree root at " << root << " is not the same as "
debugMessage << "Subgraph root at " << root << " is not the same as "
<< matchedNode << " which previously matched criteria "
<< debugString<NodeMatchCriteria>(rootCriteriaRef);
return SubtreeMatchResult<GraphType>::notMatched(debugMessage.str());
return SubgraphMatchResult<GraphType>::notMatched(debugMessage.str());
} else {
return SubtreeMatchResult<GraphType>::notMatched();
return SubgraphMatchResult<GraphType>::notMatched();
}
}
}
@ -227,19 +227,19 @@ struct SubgraphMatcher {
if (!isNodeMatch(root, rootCriteriaNode.getCriteria())) {
if (debug) {
std::ostringstream debugMessage;
debugMessage << "Subtree root at " << root
debugMessage << "Subgraph root at " << root
<< " does not match criteria "
<< debugString<NodeMatchCriteria>(rootCriteriaRef);
return SubtreeMatchResult<GraphType>::notMatched(debugMessage.str());
return SubgraphMatchResult<GraphType>::notMatched(debugMessage.str());
} else {
return SubtreeMatchResult<GraphType>::notMatched();
return SubgraphMatchResult<GraphType>::notMatched();
}
}
if (rootCriteriaNode.isNonTerminal()) {
// This is sufficient to be a match if this criteria specifies a non
// terminal node.
matchedNodes[rootCriteriaRef] = root;
return SubtreeMatchResult<GraphType>::matched();
return SubgraphMatchResult<GraphType>::matched();
}
auto& edges =
invertGraphTraversal ? root->getInEdges() : root->getOutEdges();
@ -249,7 +249,7 @@ struct SubgraphMatcher {
int numChildrenCriteria = outEdges.size();
// The current algorithm implies that the ordering of the children is
// important. The children nodes will be matched with the children subtree
// important. The children nodes will be matched with the children subgraph
// criteria in the given order.
int currentEdgeIdx = 0;
@ -273,7 +273,7 @@ struct SubgraphMatcher {
auto edge = edges[currentEdgeIdx];
auto child = invertGraphTraversal ? edge->tail() : edge->head();
if (!isSubtreeMatchInternal(
if (!isSubgraphMatchInternal(
matchedNodes, child, childrenCriteriaRef, invertGraphTraversal)
.isMatch()) {
if (!isStarCount) {
@ -287,10 +287,10 @@ struct SubgraphMatcher {
childrenCriteriaRef)
<< ". We expected " << expectedCount
<< " matches but only found " << countMatch << ".";
return SubtreeMatchResult<GraphType>::notMatched(
return SubgraphMatchResult<GraphType>::notMatched(
debugMessage.str());
} else {
return SubtreeMatchResult<GraphType>::notMatched();
return SubgraphMatchResult<GraphType>::notMatched();
}
} else {
// Otherwise, we should move on to the next children criteria.
@ -310,9 +310,9 @@ struct SubgraphMatcher {
<< " matches for child criteria "
<< debugString<NodeMatchCriteria>(childrenCriteriaRef)
<< " but only found " << countMatch;
return SubtreeMatchResult<GraphType>::notMatched(debugMessage.str());
return SubgraphMatchResult<GraphType>::notMatched(debugMessage.str());
} else {
return SubtreeMatchResult<GraphType>::notMatched();
return SubgraphMatchResult<GraphType>::notMatched();
}
}
}
@ -321,17 +321,17 @@ struct SubgraphMatcher {
// Fails because there are unmatched edges.
if (debug) {
std::ostringstream debugMessage;
debugMessage << "Unmatched children for subtree root at " << root
debugMessage << "Unmatched children for subgraph root at " << root
<< ". There are " << numEdges
<< " children, but only found " << currentEdgeIdx
<< " matches for the children criteria.";
return SubtreeMatchResult<GraphType>::notMatched(debugMessage.str());
return SubgraphMatchResult<GraphType>::notMatched(debugMessage.str());
} else {
return SubtreeMatchResult<GraphType>::notMatched();
return SubgraphMatchResult<GraphType>::notMatched();
}
}
matchedNodes[rootCriteriaRef] = root;
return SubtreeMatchResult<GraphType>::matched();
return SubgraphMatchResult<GraphType>::matched();
}
};

View file

@ -44,23 +44,23 @@ TEST(NeuralNetGraph, ReplaceGraph) {
auto mg = NNMatchGraph();
// clang-format off
auto pattern = tree(mg,
auto pattern = subgraph(mg,
matchNodeType<Relu>(), {
operatorTree(mg,
operatorSubgraph(mg,
matchNodeType<Sum>(), {
tree(mg, matchNodeType<Tensor>(), {}, 2, true)
subgraph(mg, matchNodeType<Tensor>(), {}, 2, true)
}),
});
// clang-format on
EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(sum, pattern).isMatch());
EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(sum, pattern).isMatch());
EXPECT_FALSE(
NNSubgraphMatcher::isSubtreeMatch(reluOutput, pattern).isMatch());
EXPECT_FALSE(NNSubgraphMatcher::isSubtreeMatch(input1, pattern).isMatch());
NNSubgraphMatcher::isSubgraphMatch(reluOutput, pattern).isMatch());
EXPECT_FALSE(NNSubgraphMatcher::isSubgraphMatch(input1, pattern).isMatch());
EXPECT_TRUE(NNSubgraphMatcher::isSubtreeMatch(relu, pattern).isMatch());
EXPECT_TRUE(NNSubgraphMatcher::isSubgraphMatch(relu, pattern).isMatch());
NNSubgraphMatcher::replaceSubtree(
NNSubgraphMatcher::replaceSubgraph(
graph, pattern, [](NNGraph& g, NNGraph::NodeRef relu) {
auto sumOutput = getInputs(relu)[0];
auto sum = getProducer(sumOutput);

View file

@ -41,11 +41,11 @@ TestMatchGraph::NodeRef Tree(
const Criteria& root,
const std::vector<TestMatchGraph::NodeRef>& children = {},
int count = 1) {
return tree(graph, root, children, count, false);
return subgraph(graph, root, children, count, false);
}
TestMatchGraph::NodeRef NonTerminal(const Criteria& root, int count = 1) {
return tree(graph, root, {}, count, true);
return subgraph(graph, root, {}, count, true);
}
Criteria any() {
@ -202,11 +202,11 @@ TestGraph::NodeRef getInNode(TestGraph::NodeRef node, int index) {
return node->getInEdges()[index]->tail();
}
bool isSubtreeMatch(
bool isSubgraphMatch(
TestGraph::NodeRef nodeRef,
const TestMatchGraph::NodeRef& criteria,
bool invertGraphTraversal = true) {
return TestMatcher::isSubtreeMatch(nodeRef, criteria, invertGraphTraversal)
return TestMatcher::isSubgraphMatch(nodeRef, criteria, invertGraphTraversal)
.isMatch();
}
} // namespace matcher
@ -254,32 +254,32 @@ TEST(SubgraphMatcher, IsSubtreeMatch) {
reset();
auto subtree = Tree(any(), {Tree(any()), Tree(any())});
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n4, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
EXPECT_TRUE(isSubtreeMatch(n2, subtree, false));
EXPECT_TRUE(isSubtreeMatch(n5, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
reset();
subtree = Tree(Criteria("5"), {Tree(any()), Tree(any())});
EXPECT_FALSE(isSubtreeMatch(n2, subtree, false));
EXPECT_TRUE(isSubtreeMatch(n5, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
reset();
subtree = Tree(any(), {Tree(any()), Tree(Criteria("4"))});
EXPECT_TRUE(isSubtreeMatch(n2, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n5, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n5, subtree, false));
reset();
// Accepts non terminal node
subtree = Tree(any(), {NonTerminal(any()), NonTerminal(any())});
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubtreeMatch(n2, subtree, false));
EXPECT_TRUE(isSubtreeMatch(n5, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n3, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n4, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n6, subtree, false));
EXPECT_FALSE(isSubtreeMatch(n7, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n2, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n5, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n3, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n4, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n6, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n7, subtree, false));
}
// Test subtree matching in which * (repeated) matching of children is allowed.
@ -304,11 +304,11 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
reset();
auto subtree = Tree(any(), {Tree(Criteria("2"))});
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {Tree(Criteria("2"), {}, TestMatchNode::kStarCount)});
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
// clang-format off
@ -318,7 +318,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, 3)
});
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
@ -328,7 +328,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("5"), {}, 4)
});
// Failes because exepected 4 matches of n5 but found 3.
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
@ -337,7 +337,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, TestMatchNode::kStarCount)
});
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
@ -346,7 +346,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("4"), {}, 2),
Tree(Criteria("5"), {}, TestMatchNode::kStarCount)
});
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
@ -354,7 +354,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
Tree(Criteria("3"), {}, TestMatchNode::kStarCount),
});
// Fails because there are unmatched edges.
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
reset();
subtree = Tree(any(), {
@ -365,7 +365,7 @@ TEST(SubgraphMatcher, IsSubtreeMatchRepeated) {
});
// Fails because the count is wrong; we have 2 edges to node N4 while
// the pattern expects only 1.
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subtree, false));
// clang-format on
}
@ -376,7 +376,7 @@ TEST(SubgraphMatcher, DagMatching) {
auto n4match = Tree(Criteria("4"), {
Tree(Criteria("5"))
});
auto subtree = Tree(Criteria("1"), {
auto subgraph = Tree(Criteria("1"), {
Tree(Criteria("2"), {
n4match
}),
@ -409,7 +409,7 @@ TEST(SubgraphMatcher, DagMatching) {
N5
*/
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
}
{
@ -438,7 +438,7 @@ TEST(SubgraphMatcher, DagMatching) {
*/
// This should fail because n4A and n4B are not the same node.
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
}
}
@ -447,7 +447,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) {
// clang-format off
auto n2match = Tree(Criteria("2"));
auto subtree = Tree(Criteria("1"), {
auto subgraph = Tree(Criteria("1"), {
n2match,
n2match
});
@ -461,7 +461,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) {
graph.createEdge(n1, n2);
graph.createEdge(n1, n2);
EXPECT_TRUE(isSubtreeMatch(n1, subtree, false));
EXPECT_TRUE(isSubgraphMatch(n1, subgraph, false));
}
{
@ -473,7 +473,7 @@ TEST(SubgraphMatcher, DagMatchingMultiEdges) {
graph.createEdge(n1, n2A);
graph.createEdge(n1, n2B);
EXPECT_FALSE(isSubtreeMatch(n1, subtree, false));
EXPECT_FALSE(isSubgraphMatch(n1, subgraph, false));
}
}
@ -533,7 +533,7 @@ TEST(SubgraphMatcher, DagMatchingRandomLargeGraph) {
int countMatch = 0;
for (auto node : graph.getMutableNodes()) {
if (isSubtreeMatch(node, subtree, false)) {
if (isSubgraphMatch(node, subtree, false)) {
countMatch++;
}
}
@ -545,12 +545,12 @@ TEST(SubgraphMatcher, IsSubtreeMatchRealistic) {
auto graph = DataFlowTestGraph();
auto subtree = DataFlowTestGraphCriteria();
EXPECT_FALSE(isSubtreeMatch(graph.opF, subtree));
EXPECT_FALSE(isSubtreeMatch(graph.opC, subtree));
EXPECT_FALSE(isSubtreeMatch(graph.opB, subtree));
EXPECT_FALSE(isSubtreeMatch(graph.dataOut, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.opF, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.opC, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.opB, subtree));
EXPECT_FALSE(isSubgraphMatch(graph.dataOut, subtree));
EXPECT_TRUE(isSubtreeMatch(graph.opG, subtree));
EXPECT_TRUE(isSubgraphMatch(graph.opG, subtree));
}
TEST(SubgraphMatcher, ReplaceSubtreeRealistic) {
@ -558,7 +558,7 @@ TEST(SubgraphMatcher, ReplaceSubtreeRealistic) {
auto graph = DataFlowTestGraph();
auto subtree = DataFlowTestGraphCriteria();
TestMatcher::replaceSubtree(
TestMatcher::replaceSubgraph(
graph.graph, subtree, [](TestGraph& g, TestGraph::NodeRef opG) {
auto opFused = g.createNode("opFused");