Support regex-style matching for Any and Oneof (#82853)

pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
For example, to match relu followed by one fx node:
```
    def pattern(a):
        y = a.relu()
        z = torch.ops.pseudo.any(y)
        return z
```

pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
"torch.ops.aten.foo", "torch.ops.prims.bar"]

For example, using following pattern with pseudo.oneof
```
    def pattern(a):
        y = a.relu()
        z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
        return z
```

It will have 3 matches in the following function
```
    def forward(y):
        z = y.relu()
        x = z.relu()    # first match

        x = x.relu()
        x = torch.sigmoid(x)    # second match

        x = x.relu()
        return x + 1    # third match
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82853
Approved by: https://github.com/ezyang
This commit is contained in:
Sherlock Huang 2022-08-12 06:56:08 +00:00 committed by PyTorch MergeBot
parent 0cd8526b07
commit 39e6238788
2 changed files with 107 additions and 2 deletions

View file

@ -572,7 +572,6 @@ class MultipleOutputsIdenticalAnchor:
TestCase(False, True, 0),
]
class MultipleOutputsHorizontalPattern:
@staticmethod
def forward(x):
@ -599,6 +598,61 @@ class MultipleOutputsHorizontalPattern:
TestCase(True, True, 0)
]
class PatternWithPseudoAny:
@staticmethod
def forward(x):
x = x.relu()
x = x.sigmoid()
y = x.relu()
y = y + 1
z = y.relu()
z = z.relu()
return z
@staticmethod
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.any(y)
return z
test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 3),
TestCase(True, False, 1),
TestCase(False, True, 1),
TestCase(True, True, 0)
]
class PatternWithPseudoOneof:
@staticmethod
def forward(x):
x = x.relu()
x = torch.sigmoid(x)
z = x.relu()
z = torch.relu(z)
y = x.relu()
y = y + 1
return y
@staticmethod
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.oneof(y, targets=["torch.sigmoid", "operator.add"])
return z
test_cases = [
# match_output, match_placeholder, num_matches
TestCase(False, False, 2),
TestCase(True, False, 1),
TestCase(False, True, 1),
TestCase(True, True, 0)
]
@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):
@ -616,7 +670,9 @@ class TestFXMatcherUtils(JitTestCase):
MultipleOutputsMultipleOverlappingMatches,
MultipleOutputsMultipleNonOverlappingMatches,
MultipleOutputsIdenticalAnchor,
MultipleOutputsHorizontalPattern
MultipleOutputsHorizontalPattern,
PatternWithPseudoAny,
PatternWithPseudoOneof,
])
def test_subgraph_matcher(self, test_model):
traced = symbolic_trace(test_model.forward)

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from collections import defaultdict
import copy
import torch.library
from torch.fx.graph import Graph
from torch.fx.node import Node
from torch.fx._compatibility import compatibility
@ -9,6 +10,42 @@ from typing import Dict, List, Set
__all__ = ['SubgraphMatcher', 'InternalMatch']
pseudo = torch.library.Library("pseudo", "DEF")
pseudo.define("any() -> ()")
"""
pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
For example, to match relu followed by one fx node:
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.any(y)
return z
"""
pseudo.define("oneof(*, str[] targets) -> ()")
"""
pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
"torch.ops.aten.foo", "torch.ops.prims.bar"]
For example, using following pattern with pseudo.oneof
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
return z
It will have 3 matches in the following function
def forward(y):
z = y.relu()
x = z.relu() # first match
x = x.relu()
x = torch.sigmoid(x) # second match
x = x.relu()
return x + 1 # third match
"""
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch():
@ -80,6 +117,18 @@ class SubgraphMatcher:
if not self.match_placeholder and pn.op == "placeholder":
return True
if pn.target == torch.ops.pseudo.any:
return True
if pn.target == torch.ops.pseudo.oneof:
permissible_targets: List[str] = pn.kwargs.get("targets", list()) # type: ignore[assignment]
assert isinstance(permissible_targets, list), \
"pseudo.oneof(permissible_targets=[\"foo\", \"bar\"]) only accept targets as a list"
assert len(permissible_targets) > 0, "please specific as least one target for pseudo.oneof"
if gn._pretty_print_target(gn.target) in permissible_targets:
return True
if pn.op == gn.op:
if pn.op == "placeholder" or pn.op == "output":
return True