mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Support regex-style matching for Any and Oneof (#82853)
pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
For example, to match relu followed by one fx node:
```
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.any(y)
return z
```
pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
"torch.ops.aten.foo", "torch.ops.prims.bar"]
For example, using following pattern with pseudo.oneof
```
def pattern(a):
y = a.relu()
z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
return z
```
It will have 3 matches in the following function
```
def forward(y):
z = y.relu()
x = z.relu() # first match
x = x.relu()
x = torch.sigmoid(x) # second match
x = x.relu()
return x + 1 # third match
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82853
Approved by: https://github.com/ezyang
This commit is contained in:
parent
0cd8526b07
commit
39e6238788
2 changed files with 107 additions and 2 deletions
|
|
@ -572,7 +572,6 @@ class MultipleOutputsIdenticalAnchor:
|
|||
TestCase(False, True, 0),
|
||||
]
|
||||
|
||||
|
||||
class MultipleOutputsHorizontalPattern:
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
|
|
@ -599,6 +598,61 @@ class MultipleOutputsHorizontalPattern:
|
|||
TestCase(True, True, 0)
|
||||
]
|
||||
|
||||
class PatternWithPseudoAny:
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
x = x.relu()
|
||||
x = x.sigmoid()
|
||||
|
||||
y = x.relu()
|
||||
y = y + 1
|
||||
|
||||
z = y.relu()
|
||||
z = z.relu()
|
||||
|
||||
return z
|
||||
|
||||
@staticmethod
|
||||
def pattern(a):
|
||||
y = a.relu()
|
||||
z = torch.ops.pseudo.any(y)
|
||||
return z
|
||||
|
||||
test_cases = [
|
||||
# match_output, match_placeholder, num_matches
|
||||
TestCase(False, False, 3),
|
||||
TestCase(True, False, 1),
|
||||
TestCase(False, True, 1),
|
||||
TestCase(True, True, 0)
|
||||
]
|
||||
|
||||
class PatternWithPseudoOneof:
|
||||
@staticmethod
|
||||
def forward(x):
|
||||
x = x.relu()
|
||||
x = torch.sigmoid(x)
|
||||
|
||||
z = x.relu()
|
||||
z = torch.relu(z)
|
||||
|
||||
y = x.relu()
|
||||
y = y + 1
|
||||
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def pattern(a):
|
||||
y = a.relu()
|
||||
z = torch.ops.pseudo.oneof(y, targets=["torch.sigmoid", "operator.add"])
|
||||
return z
|
||||
|
||||
test_cases = [
|
||||
# match_output, match_placeholder, num_matches
|
||||
TestCase(False, False, 2),
|
||||
TestCase(True, False, 1),
|
||||
TestCase(False, True, 1),
|
||||
TestCase(True, True, 0)
|
||||
]
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestFXMatcherUtils(JitTestCase):
|
||||
|
|
@ -616,7 +670,9 @@ class TestFXMatcherUtils(JitTestCase):
|
|||
MultipleOutputsMultipleOverlappingMatches,
|
||||
MultipleOutputsMultipleNonOverlappingMatches,
|
||||
MultipleOutputsIdenticalAnchor,
|
||||
MultipleOutputsHorizontalPattern
|
||||
MultipleOutputsHorizontalPattern,
|
||||
PatternWithPseudoAny,
|
||||
PatternWithPseudoOneof,
|
||||
])
|
||||
def test_subgraph_matcher(self, test_model):
|
||||
traced = symbolic_trace(test_model.forward)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
import torch.library
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.node import Node
|
||||
from torch.fx._compatibility import compatibility
|
||||
|
|
@ -9,6 +10,42 @@ from typing import Dict, List, Set
|
|||
__all__ = ['SubgraphMatcher', 'InternalMatch']
|
||||
|
||||
|
||||
pseudo = torch.library.Library("pseudo", "DEF")
|
||||
|
||||
pseudo.define("any() -> ()")
|
||||
"""
|
||||
pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs.
|
||||
For example, to match relu followed by one fx node:
|
||||
def pattern(a):
|
||||
y = a.relu()
|
||||
z = torch.ops.pseudo.any(y)
|
||||
return z
|
||||
"""
|
||||
|
||||
pseudo.define("oneof(*, str[] targets) -> ()")
|
||||
"""
|
||||
pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list.
|
||||
`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid",
|
||||
"torch.ops.aten.foo", "torch.ops.prims.bar"]
|
||||
|
||||
For example, using following pattern with pseudo.oneof
|
||||
def pattern(a):
|
||||
y = a.relu()
|
||||
z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"])
|
||||
return z
|
||||
|
||||
It will have 3 matches in the following function
|
||||
def forward(y):
|
||||
z = y.relu()
|
||||
x = z.relu() # first match
|
||||
|
||||
x = x.relu()
|
||||
x = torch.sigmoid(x) # second match
|
||||
|
||||
x = x.relu()
|
||||
return x + 1 # third match
|
||||
"""
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@dataclass
|
||||
class InternalMatch():
|
||||
|
|
@ -80,6 +117,18 @@ class SubgraphMatcher:
|
|||
if not self.match_placeholder and pn.op == "placeholder":
|
||||
return True
|
||||
|
||||
if pn.target == torch.ops.pseudo.any:
|
||||
return True
|
||||
|
||||
if pn.target == torch.ops.pseudo.oneof:
|
||||
permissible_targets: List[str] = pn.kwargs.get("targets", list()) # type: ignore[assignment]
|
||||
assert isinstance(permissible_targets, list), \
|
||||
"pseudo.oneof(permissible_targets=[\"foo\", \"bar\"]) only accept targets as a list"
|
||||
assert len(permissible_targets) > 0, "please specific as least one target for pseudo.oneof"
|
||||
|
||||
if gn._pretty_print_target(gn.target) in permissible_targets:
|
||||
return True
|
||||
|
||||
if pn.op == gn.op:
|
||||
if pn.op == "placeholder" or pn.op == "output":
|
||||
return True
|
||||
|
|
|
|||
Loading…
Reference in a new issue