Add Python serialization to Pattern Matcher patterns (#108894)

Adds a Python pretty printer to the pattern matcher that serializes patterns as executable Python. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build PyTorch, then run `gen_attention_patterns.py`.
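For illustration, a minimal sketch of the round trip this enables, assuming `inference_graph` (the trace_fn used in the diff below) is importable from `torch._inductor.pattern_matcher`; the printed output is representative, not exact:

```python
import torch
from torch._inductor.pattern_matcher import (
    gen_pattern,
    inference_graph,  # assumed import location; it is the trace_fn used in this diff
    PatternPrettyPrinter,
)

# A toy search function standing in for an attention pattern.
def search_fn(x, y):
    return (x @ y).relu()

with torch._subclasses.FakeTensorMode():
    example_inputs = [torch.empty(4, 4), torch.empty(4, 4)]
    pattern = gen_pattern(search_fn, example_inputs, inference_graph)

# Prints executable python along the lines of:
#   mm_default = CallFunction(aten.mm.default, KeywordArg('x'), KeywordArg('y'))
#   output = CallFunction(aten.relu.default, mm_default)
print(PatternPrettyPrinter.run(pattern))
```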

Since there is a line limit for PRs, I'm only including `_sfdp_pattern_1` in this first diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108894
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663
eellison 2023-09-19 17:08:09 -07:00 committed by PyTorch MergeBot
parent 1a5e0edf56
commit 16d608d70d
9 changed files with 476 additions and 35 deletions

View file

@@ -16,6 +16,7 @@ exclude_patterns = [
'functorch/docs/**',
'functorch/examples/**',
'functorch/notebooks/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'scripts/**',
'test/generated_type_hints_smoketest.py',
# Tests from the NumPy test suite
@@ -195,6 +196,7 @@ include_patterns = [
]
exclude_patterns = [
'**/fb/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/index_propagation.py',
'torch/_inductor/ir.py',
'torch/_inductor/scheduler.py',
@@ -980,6 +982,7 @@ exclude_patterns = [
'**/fb/**',
'third_party/**/*.py',
'third_party/**/*.pyi',
'torch/_inductor/fx_passes/serialized_patterns/**',
# These files are all grandfathered in, feel free to remove from this list
# as necessary
'test/_nvfuser/__init__.py',
@@ -2651,6 +2654,7 @@ exclude_patterns = [
'caffe2/**',
'functorch/docs/**',
'functorch/notebooks/**',
'torch/_inductor/fx_passes/serialized_patterns/**',
'scripts/**',
'third_party/**',
'fb/**',

View file

@@ -7,7 +7,6 @@ import torch._inductor.config
import torch.utils.checkpoint
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_ATTENTION,
@@ -24,7 +23,6 @@ def checkpoint_wrapper(fn):
return inner
@config.patch(fallback_random=True)
class TestSDPAPatternRewriterTemplate(TestCase):
def _clone_inputs(self, inputs):
def clone(x):

View file

@@ -8,6 +8,15 @@ import torch._inductor.config as inductor_config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
from torch._inductor.fx_passes import joint_graph
from torch._inductor.fx_passes.serialized_patterns.central_index import (
get_serialized_pattern,
)
from torch._inductor.pattern_matcher import (
_TargetExpr,
gen_pattern,
PatternExpr,
PatternPrettyPrinter,
)
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
@@ -776,6 +785,66 @@ class TestPaternMatcher(TestCase):
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
def test_fuse_attention_roundtrip_pattern(self):
# are we losing anything in serialization
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
global_vals = {
"aten": torch.ops.aten,
"prims": torch.ops.prims,
"torch": torch,
}
for name in dir(torch._inductor.pattern_matcher):
attr = getattr(torch._inductor.pattern_matcher, name)
if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
global_vals[name] = attr
with torch._subclasses.FakeTensorMode():
for _, kwargs in _get_sfdp_patterns():
gen_kwargs = {
key: kwargs[key]
for key in (
"search_fn",
"example_inputs",
"trace_fn",
"scalar_workaround",
)
}
pattern = gen_pattern(**gen_kwargs)
pattern_pp = PatternPrettyPrinter.run(pattern)
env = global_vals.copy()
exec(pattern_pp, env)
pattern_2 = env["output"]
self.assertEqual(pattern_pp, PatternPrettyPrinter.run(pattern_2))
def test_fuse_attention_all_patterns_serialized(self):
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
with torch._subclasses.FakeTensorMode():
for key, kwargs in _get_sfdp_patterns():
gen_kwargs = {
key: kwargs[key]
for key in (
"search_fn",
"example_inputs",
"trace_fn",
"scalar_workaround",
)
}
pattern = gen_pattern(**gen_kwargs)
pattern_pp = PatternPrettyPrinter.run(pattern)
search_fn_pattern = get_serialized_pattern(key)
if search_fn_pattern is None:
continue
self.assertEqual(
pattern_pp,
PatternPrettyPrinter.run(search_fn_pattern),
msg=f"Found mismatched pattern {key}. Run gen_attention_patterns.py",
)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:

View file

@@ -366,8 +366,7 @@ def partialize_and_update_signature(func, **kwargs):
return wrapper
@functools.lru_cache(None)
def _sfdp_init():
def _get_sfdp_patterns():
from .joint_graph import patterns
if torch.cuda.is_available():
@@ -474,31 +473,46 @@ def _sfdp_init():
_sfdp_scale_factor_check(aten.div.Tensor),
),
]:
training_args = [*args, *workaround.values()] # type: ignore[attr-defined]
register_replacement(
pattern,
replacement,
training_args,
training_graph,
patterns,
extra_check=extra_check,
scalar_workaround=workaround,
)
# XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
# gets serialized to a python file and does not require tracing at runtime.
assert isinstance(workaround, dict)
training_args = [*args, *workaround.values()]
name = pattern.__name__
yield f"{name}_training", {
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": training_args,
"trace_fn": training_graph,
"pass_dict": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
if workaround:
assert isinstance(workaround, dict)
assert len(workaround) == 1 and "dropout_p" in workaround
# functools.partial insufficient because we look at signature downstream
pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
replacement = partialize_and_update_signature(replacement, dropout_p=0.0)
workaround = {}
yield f"{name}_inference", {
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
"trace_fn": inference_graph,
"pass_dict": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
@functools.lru_cache(None)
def _sfdp_init():
from .serialized_patterns.central_index import get_serialized_pattern
for key, register_replacement_kwargs in _get_sfdp_patterns():
search_fn_pattern = get_serialized_pattern(key)
register_replacement(
pattern,
replacement,
args,
inference_graph,
patterns,
extra_check=extra_check,
scalar_workaround=workaround,
**register_replacement_kwargs, search_fn_pattern=search_fn_pattern
)

View file

@@ -0,0 +1,95 @@
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
import torch
import torch._inductor
aten = torch.ops.aten
prims = torch.ops.prims
from torch._inductor.pattern_matcher import (
Arg,
CallFunction,
CallFunctionVarArgs,
CallMethod,
CallMethodVarArgs,
ExclusiveKeywordArg,
Ignored,
KeywordArg,
ListOf,
MultiOutputPattern,
PatternExpr,
RepeatedExpr,
_TargetArgsExpr,
_TargetExpr,
_TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
view_default_9,
permute_default_4,
view_default_11,
None
])
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())
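For orientation, the aten-level graph above is what tracing and decomposition produce from an eager-level attention pattern of roughly this shape (a sketch from memory; the actual `_sfdp_pattern_1` definition lives in `fuse_attention.py` and is not shown in this diff):

```python
import torch

# Sketch: div-scaled dot-product attention. softmax decomposes into the
# amax/sub/exp/sum/div chain above; matmul into the expand/view/bmm chains.
def _sfdp_pattern_1(query, key, value, inv_scale):
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .matmul(value)
    )
```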

View file

@@ -0,0 +1,22 @@
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference)
central_index = {
'_sfdp_pattern_1_training': _sfdp_pattern_1_training,
'_sfdp_pattern_1_inference': _sfdp_pattern_1_inference,
}
def get_serialized_pattern(key):
import torch._inductor # noqa: F401
from torch._inductor import config
if config.fallback_random:
return None
# TODO - could add more validation that the same set of decomps used when
# tracing SDPA are also used in current context. softmax, dropout, etc
# decomp use is stable so not an issue in practice.
return central_index.get(key)
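An illustrative lookup against the index above (the key is one of the two registered so far):

```python
from torch._inductor.fx_passes.serialized_patterns.central_index import (
    get_serialized_pattern,
)

# Returns the pre-built PatternExpr, or None under config.fallback_random,
# in which case register_replacement falls back to tracing at runtime.
pattern = get_serialized_pattern("_sfdp_pattern_1_inference")
```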

View file

@@ -204,6 +204,9 @@ class Ignored(PatternExpr):
def __repr__(self):
return "*"
def pretty_print(self, pp: "PatternPrettyPrinter"):
return "Ignored()"
class KeywordArg(PatternExpr):
"""
@@ -262,11 +265,15 @@ class _TargetExpr(PatternExpr):
self.users = users
def fns_repr(self):
return (
f"[{self.fns[0].__name__}, ...]"
if len(self.fns) > 1
else self.fns[0].__name__
)
fn_name = self.fns[0].__name__
if len(self.fns) > 1:
return f"[{fn_name}, ...]"
elif self.fns[0] is getattr(torch, fn_name, None):
return f"torch.{fn_name}"
elif isinstance(self.fns[0], torch._ops.OpOverload):
return str(self.fns[0])
else:
return self.fns[0].__name__
def __repr__(self):
return f"{self.__class__.__name__}({self.fns_repr()})"
@@ -338,6 +345,18 @@ class _TargetArgsExpr(_TargetExpr):
]
return f"{self.__class__.__name__}({', '.join(args)})"
def pretty_print(self, pp: "PatternPrettyPrinter"):
args = [
self.fns_repr(),
*(pp.pretty_print(x) for x in self.args),
*[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
]
if self.users > 1:
args.append(f"_users={self.users}")
joiner_str = ", "
return f"{self.__class__.__name__}({joiner_str.join(args)})"
def _match(self, node: torch.fx.Node, ctx: MatchContext):
if (
not self._match_fns(node)
@@ -488,6 +507,13 @@ class MultiOutputPattern(PatternExpr):
def __repr__(self):
return f"{self.__class__.__name__}({self.outputs})"
def pretty_print(self, pp: "PatternPrettyPrinter"):
args = [pp.pretty_print(x) for x in self.outputs]
joiner_str = f",\n{' '}"
str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
str_out = f"{str_out}\n])"
return str_out
def _match(self, node: torch.fx.Node, ctx: MatchContext):
m = ctx.match(self.outputs[0], node)
if not m:
@@ -553,6 +579,58 @@ class RepeatedExpr(PatternExpr):
return m
class PatternPrettyPrinter:
"""
Serializes Patterns to executable python.
XXX: currently only used and tested for fuse attention patterns. May not cover
all patterns.
"""
def __init__(self):
self.namespace = torch.fx.graph._Namespace()
self.memoized_objs_names: Dict[PatternExpr, str] = {}
self.memoized_objs_pp: Dict[PatternExpr, str] = {}
@staticmethod
def run(obj: PatternExpr, output_name="output"):
"""
Serializes obj to python code with obj written out to `output_name`
"""
pp = PatternPrettyPrinter()
out_str = obj.pretty_print(pp=pp)
output = []
for key in pp.memoized_objs_names:
output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
output.append(f"{output_name} = {out_str}")
return "\n".join(output)
def pretty_print(self, obj):
if isinstance(obj, _TargetArgsExpr):
if memoized_name := self.memoized_objs_names.get(obj):
return memoized_name
else:
return self.memoize(obj)
if hasattr(obj, "pretty_print"):
return obj.pretty_print(self)
return repr(obj)
def memoize(self, obj):
obj_str = obj.pretty_print(self)
obj_name = obj.fns_repr()
for prefix in ("aten.", "torch.", "prims."):
obj_name = obj_name.replace(prefix, "")
tmp_name = self.namespace.create_name(obj_name, None)
self.memoized_objs_names[obj] = tmp_name
self.memoized_objs_pp[obj] = obj_str
return tmp_name
@dataclasses.dataclass
class PatternEntry:
pattern: PatternExpr
@@ -706,6 +784,7 @@ def register_replacement(
extra_check=_return_true,
scalar_workaround=(),
exclusive_arg_names=(),
search_fn_pattern=None,
):
"""
Create a replacement rule based on example functions that get traced
@@ -772,14 +851,17 @@
requires_grad = [
isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
]
search_gm = trace_fn(search_fn, example_inputs)
pattern = fx_to_pattern(
search_gm,
ignore_types=(int, float, list, torch.device, torch.dtype),
argnames=argnames,
scalar_workaround=scalar_workaround,
exclusive_arg_names=exclusive_arg_names,
)
if search_fn_pattern is None:
pattern = gen_pattern(
search_fn,
example_inputs,
trace_fn,
scalar_workaround,
exclusive_arg_names,
)
else:
pattern = search_fn_pattern
assert repr(pattern) not in _seen_patterns
_seen_patterns.add(repr(pattern))
pattern = ReplacementPatternEntry(
@@ -790,6 +872,21 @@
pattern.register(pass_dict)
@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
) -> PatternExpr:
argnames = [*inspect.signature(search_fn).parameters.keys()]
search_gm = trace_fn(search_fn, example_inputs)
return fx_to_pattern(
search_gm,
ignore_types=(int, float, list, torch.device, torch.dtype),
argnames=argnames,
scalar_workaround=scalar_workaround,
exclusive_arg_names=exclusive_arg_names,
)
def register_lowering_pattern(
pattern, extra_check=_return_true, *, pass_dict, prepend=False
):
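As a concrete illustration of the `PatternPrettyPrinter` memoization added above, a small sketch of how multi-user subexpressions get hoisted to named temporaries (the commented output is representative of the classes in this diff, not verified against this exact revision):

```python
import torch
from torch._inductor.pattern_matcher import (
    CallFunction,
    KeywordArg,
    PatternPrettyPrinter,
)

aten = torch.ops.aten

# A subexpression with two users is printed once and referenced by name.
shared = CallFunction(aten.exp.default, KeywordArg("x"), _users=2)
out = CallFunction(aten.add.Tensor, shared, shared)
print(PatternPrettyPrinter.run(out))
# exp_default = CallFunction(aten.exp.default, KeywordArg('x'), _users=2)
# output = CallFunction(aten.add.Tensor, exp_default, exp_default)
```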

View file

@@ -0,0 +1,142 @@
#!/usr/bin/env python3
import os
import shutil
from collections import defaultdict
from pathlib import Path
import torch._inductor
from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
from torch._inductor.pattern_matcher import (
_TargetExpr,
gen_pattern,
PatternExpr,
PatternPrettyPrinter,
)
auto_generated_msg = """# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
"""
def get_file_template() -> str:
file_template = f"""# noqa: F401, E501
{auto_generated_msg}
import torch
import torch._inductor
aten = torch.ops.aten
prims = torch.ops.prims
"""
pattern_matcher_imports = []
for name in dir(torch._inductor.pattern_matcher):
attr = getattr(torch._inductor.pattern_matcher, name)
if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
pattern_matcher_imports.append(name)
formatted_imports = ",\n ".join(pattern_matcher_imports)
formatted_imports = (
f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
)
return f"{file_template}{formatted_imports}"
def get_central_index_epilogue() -> str:
epilogue = """
def get_serialized_pattern(key):
import torch._inductor # noqa: F401
from torch._inductor import config
if config.fallback_random:
return None
# TODO - could add more validation that the same set of decomps used when
# tracing SDPA are also used in current context. softmax, dropout, etc
# decomp use is stable so not an issue in practice.
return central_index.get(key)
"""
return epilogue
def clean_directory(dir: Path) -> None:
for filename in os.listdir(dir):
file_path = os.path.join(dir, filename)
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
def serialize_functions() -> None:
file_path = Path.cwd() / "torch" / "_inductor" / "fx_passes" / "serialized_patterns"
if not file_path.exists():
raise Exception(
"Could not find serialized patterns directory. Make sure you are at Pytorch root directory"
)
clean_directory(file_path)
with open(file_path / "__init__.py", "w"):
pass
central_index = {}
file_to_keys = defaultdict(list)
seen_patterns = set()
file_template = get_file_template()
for i, (
key,
kwargs,
) in enumerate(_get_sfdp_patterns()):
pattern_name = kwargs["search_fn"].__name__
gen_kwargs = {
key: kwargs[key]
for key in ("search_fn", "example_inputs", "trace_fn", "scalar_workaround")
}
# temporary to batch adding new patterns
if i >= 2:
continue
from torch._functorch import config as functorch_config
with functorch_config.patch(functionalize_rng_ops=False):
pattern = gen_pattern(**gen_kwargs)
serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=key)
if pattern_name not in seen_patterns:
write_mode = "w"
seen_patterns.add(pattern_name)
else:
write_mode = "a"
with open(file_path / f"{pattern_name}.py", write_mode) as f:
if write_mode == "w":
f.write(file_template)
else:
f.write("\n\n")
f.write(serialized_pattern)
f.write("\n")
central_index[f"{key}"] = f"{pattern_name}.py"
file_to_keys[pattern_name].append(f"{key}")
with open(file_path / "central_index.py", "w") as f:
f.write(auto_generated_msg)
for pattern_name, keys in file_to_keys.items():
f.write(f"from .{pattern_name} import ({', '.join(keys)})\n")
f.write("\ncentral_index = {\n")
for k in central_index.keys():
f.write(f" '{k}': {k},\n")
f.write("}\n\n")
f.write(get_central_index_epilogue())
if __name__ == "__main__":
with torch._subclasses.FakeTensorMode():
serialize_functions()
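Per the auto-generated headers above, the script is intended to be run from the root of a PyTorch checkout, e.g. `cd ~/pytorch && python torchgen/fuse_attention_patterns/gen_attention_patterns.py`, after building PyTorch as noted in the commit message.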