From 39e623878872ee365400a1d1534b1d85a5b1bdb1 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Fri, 12 Aug 2022 06:56:08 +0000 Subject: [PATCH] Support regex-style matching for Any and Oneof (#82853) pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs. For example, to match relu followed by one fx node: ``` def pattern(a): y = a.relu() z = torch.ops.pseudo.any(y) return z ``` pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list. `targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid", "torch.ops.aten.foo", "torch.ops.prims.bar"] For example, using following pattern with pseudo.oneof ``` def pattern(a): y = a.relu() z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"]) return z ``` It will have 3 matches in the following function ``` def forward(y): z = y.relu() x = z.relu() # first match x = x.relu() x = torch.sigmoid(x) # second match x = x.relu() return x + 1 # third match ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/82853 Approved by: https://github.com/ezyang --- test/test_fx_passes.py | 60 +++++++++++++++++++++++++- torch/fx/passes/utils/matcher_utils.py | 49 +++++++++++++++++++++ 2 files changed, 107 insertions(+), 2 deletions(-) diff --git a/test/test_fx_passes.py b/test/test_fx_passes.py index af02a5a2e83..2ff1959fe21 100644 --- a/test/test_fx_passes.py +++ b/test/test_fx_passes.py @@ -572,7 +572,6 @@ class MultipleOutputsIdenticalAnchor: TestCase(False, True, 0), ] - class MultipleOutputsHorizontalPattern: @staticmethod def forward(x): @@ -599,6 +598,61 @@ class MultipleOutputsHorizontalPattern: TestCase(True, True, 0) ] +class PatternWithPseudoAny: + @staticmethod + def forward(x): + x = x.relu() + x = x.sigmoid() + + y = x.relu() + y = y + 1 + + z = y.relu() + z = z.relu() + + return z + + @staticmethod + def pattern(a): + y = a.relu() + z = torch.ops.pseudo.any(y) + return z + + test_cases = [ + # match_output, match_placeholder, num_matches + TestCase(False, False, 3), + TestCase(True, False, 1), + TestCase(False, True, 1), + TestCase(True, True, 0) + ] + +class PatternWithPseudoOneof: + @staticmethod + def forward(x): + x = x.relu() + x = torch.sigmoid(x) + + z = x.relu() + z = torch.relu(z) + + y = x.relu() + y = y + 1 + + return y + + @staticmethod + def pattern(a): + y = a.relu() + z = torch.ops.pseudo.oneof(y, targets=["torch.sigmoid", "operator.add"]) + return z + + test_cases = [ + # match_output, match_placeholder, num_matches + TestCase(False, False, 2), + TestCase(True, False, 1), + TestCase(False, True, 1), + TestCase(True, True, 0) + ] @instantiate_parametrized_tests class TestFXMatcherUtils(JitTestCase): @@ -616,7 +670,9 @@ class TestFXMatcherUtils(JitTestCase): MultipleOutputsMultipleOverlappingMatches, MultipleOutputsMultipleNonOverlappingMatches, MultipleOutputsIdenticalAnchor, - MultipleOutputsHorizontalPattern + MultipleOutputsHorizontalPattern, + PatternWithPseudoAny, + PatternWithPseudoOneof, ]) def test_subgraph_matcher(self, test_model): traced = symbolic_trace(test_model.forward) diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index 31ae96a47fe..13d34339882 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from collections import defaultdict import copy +import torch.library from torch.fx.graph import Graph from torch.fx.node import Node from torch.fx._compatibility import compatibility @@ -9,6 +10,42 @@ from typing import Dict, List, Set __all__ = ['SubgraphMatcher', 'InternalMatch'] +pseudo = torch.library.Library("pseudo", "DEF") + +pseudo.define("any() -> ()") +""" +pseudo.any is a wildcard node that can be matched with any fx node with arbitrary number of inputs and outputs. +For example, to match relu followed by one fx node: + def pattern(a): + y = a.relu() + z = torch.ops.pseudo.any(y) + return z +""" + +pseudo.define("oneof(*, str[] targets) -> ()") +""" +pseudo.oneof is a special node that can be matched with a fx node whose target is in the permissible list. +`targets` must be be a list of qualified name for operators, e.g. ["operator.add", "torch.sigmoid", +"torch.ops.aten.foo", "torch.ops.prims.bar"] + +For example, using following pattern with pseudo.oneof + def pattern(a): + y = a.relu() + z = torch.ops.pseudo.oneof(y, targets=["relu", "torch.sigmoid", "operator.add"]) + return z + +It will have 3 matches in the following function + def forward(y): + z = y.relu() + x = z.relu() # first match + + x = x.relu() + x = torch.sigmoid(x) # second match + + x = x.relu() + return x + 1 # third match +""" + @compatibility(is_backward_compatible=False) @dataclass class InternalMatch(): @@ -80,6 +117,18 @@ class SubgraphMatcher: if not self.match_placeholder and pn.op == "placeholder": return True + if pn.target == torch.ops.pseudo.any: + return True + + if pn.target == torch.ops.pseudo.oneof: + permissible_targets: List[str] = pn.kwargs.get("targets", list()) # type: ignore[assignment] + assert isinstance(permissible_targets, list), \ + "pseudo.oneof(permissible_targets=[\"foo\", \"bar\"]) only accept targets as a list" + assert len(permissible_targets) > 0, "please specific as least one target for pseudo.oneof" + + if gn._pretty_print_target(gn.target) in permissible_targets: + return True + if pn.op == gn.op: if pn.op == "placeholder" or pn.op == "output": return True