Add Python serialization to Pattern Matcher patterns (#108894)
Adds a Python pretty printer to the pattern matcher that serializes patterns as Python source. Generating our fuse attention patterns was taking 4 seconds of compile time, which will only get worse as we add more variants (which I will do in the rest of this stack). To write out patterns, build PyTorch, then run `gen_attention_patterns.py`. Since there is a line limit for PRs, I'm only including `_sfdp_pattern_1` in this first diff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108894
Approved by: https://github.com/yanboliang
ghstack dependencies: #109663
This commit is contained in: parent 1a5e0edf56, commit 16d608d70d
9 changed files with 476 additions and 35 deletions
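
Note (not part of the diff): the round trip the new serializer enables looks roughly like the sketch below. The toy aten.div/aten.exp pattern and the variable names are invented for illustration; the real patterns are produced by tracing the SDPA search functions (see the new test and the generated `_sfdp_pattern_1` file further down).

import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg, PatternPrettyPrinter

aten = torch.ops.aten

# A hand-written toy pattern (the serialized files build patterns the same way).
div = CallFunction(aten.div.Tensor, KeywordArg("x"), KeywordArg("inv_scale"))
pattern = CallFunction(aten.exp.default, div, _users=2)

# Serialize to executable Python; the generated source binds the pattern to `output`.
src = PatternPrettyPrinter.run(pattern)
print(src)

# Re-execute the generated source and check the round trip is lossless.
env = {"aten": aten, "CallFunction": CallFunction, "KeywordArg": KeywordArg}
exec(src, env)
assert PatternPrettyPrinter.run(env["output"]) == src

Because the output is plain Python, serialized patterns can be imported from a generated module instead of being re-traced at compile time.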

@@ -16,6 +16,7 @@ exclude_patterns = [
    'functorch/docs/**',
    'functorch/examples/**',
    'functorch/notebooks/**',
    'torch/_inductor/fx_passes/serialized_patterns/**',
    'scripts/**',
    'test/generated_type_hints_smoketest.py',
    # Tests from the NumPy test suite

@@ -195,6 +196,7 @@ include_patterns = [
]
exclude_patterns = [
    '**/fb/**',
    'torch/_inductor/fx_passes/serialized_patterns/**',
    'torch/_inductor/index_propagation.py',
    'torch/_inductor/ir.py',
    'torch/_inductor/scheduler.py',

@@ -980,6 +982,7 @@ exclude_patterns = [
    '**/fb/**',
    'third_party/**/*.py',
    'third_party/**/*.pyi',
    'torch/_inductor/fx_passes/serialized_patterns/**',
    # These files are all grandfathered in, feel free to remove from this list
    # as necessary
    'test/_nvfuser/__init__.py',

@@ -2651,6 +2654,7 @@ exclude_patterns = [
    'caffe2/**',
    'functorch/docs/**',
    'functorch/notebooks/**',
    'torch/_inductor/fx_passes/serialized_patterns/**',
    'scripts/**',
    'third_party/**',
    'fb/**',

@@ -7,7 +7,6 @@ import torch._inductor.config
import torch.utils.checkpoint
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,

@@ -24,7 +23,6 @@ def checkpoint_wrapper(fn):
    return inner


@config.patch(fallback_random=True)
class TestSDPAPatternRewriterTemplate(TestCase):
    def _clone_inputs(self, inputs):
        def clone(x):

@@ -8,6 +8,15 @@ import torch._inductor.config as inductor_config
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.utils import count_calls, counters
from torch._inductor.fx_passes import joint_graph
from torch._inductor.fx_passes.serialized_patterns.central_index import (
    get_serialized_pattern,
)
from torch._inductor.pattern_matcher import (
    _TargetExpr,
    gen_pattern,
    PatternExpr,
    PatternPrettyPrinter,
)
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater

@@ -776,6 +785,66 @@ class TestPaternMatcher(TestCase):
        _, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
        FileCheck().check_not("extern_kernels.addmm(").run(code[0])

    def test_fuse_attention_roundtrip_pattern(self):
        # are we losing anything in serialization
        from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns

        global_vals = {
            "aten": torch.ops.aten,
            "prims": torch.ops.prims,
            "torch": torch,
        }

        for name in dir(torch._inductor.pattern_matcher):
            attr = getattr(torch._inductor.pattern_matcher, name)
            if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
                global_vals[name] = attr

        with torch._subclasses.FakeTensorMode():
            for _, kwargs in _get_sfdp_patterns():
                gen_kwargs = {
                    key: kwargs[key]
                    for key in (
                        "search_fn",
                        "example_inputs",
                        "trace_fn",
                        "scalar_workaround",
                    )
                }
                pattern = gen_pattern(**gen_kwargs)
                pattern_pp = PatternPrettyPrinter.run(pattern)
                env = global_vals.copy()
                exec(pattern_pp, env)
                pattern_2 = env["output"]
                self.assertEqual(pattern_pp, PatternPrettyPrinter.run(pattern_2))

    def test_fuse_attention_all_patterns_serialized(self):
        from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns

        with torch._subclasses.FakeTensorMode():
            for key, kwargs in _get_sfdp_patterns():
                gen_kwargs = {
                    key: kwargs[key]
                    for key in (
                        "search_fn",
                        "example_inputs",
                        "trace_fn",
                        "scalar_workaround",
                    )
                }
                pattern = gen_pattern(**gen_kwargs)
                pattern_pp = PatternPrettyPrinter.run(pattern)

                search_fn_pattern = get_serialized_pattern(key)
                if search_fn_pattern is None:
                    continue

                self.assertEqual(
                    pattern_pp,
                    PatternPrettyPrinter.run(search_fn_pattern),
                    msg=f"Found mismatched pattern {key}. Run gen_attention_patterns.py",
                )


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA:

@@ -366,8 +366,7 @@ def partialize_and_update_signature(func, **kwargs):
    return wrapper


@functools.lru_cache(None)
def _sfdp_init():
def _get_sfdp_patterns():
    from .joint_graph import patterns

    if torch.cuda.is_available():

@@ -474,31 +473,46 @@ def _sfdp_init():
            _sfdp_scale_factor_check(aten.div.Tensor),
        ),
    ]:
        training_args = [*args, *workaround.values()]  # type: ignore[attr-defined]
        register_replacement(
            pattern,
            replacement,
            training_args,
            training_graph,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
        )
        # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern
        # gets serialized to a python file and does not require tracing at runtime.
        assert isinstance(workaround, dict)
        training_args = [*args, *workaround.values()]
        name = pattern.__name__

        yield f"{name}_training", {
            "search_fn": pattern,
            "replace_fn": replacement,
            "example_inputs": training_args,
            "trace_fn": training_graph,
            "pass_dict": patterns,
            "extra_check": extra_check,
            "scalar_workaround": workaround,
        }

        if workaround:
            assert isinstance(workaround, dict)
            assert len(workaround) == 1 and "dropout_p" in workaround
            # functools.partial insufficient because we look at signature downstream
            pattern = partialize_and_update_signature(pattern, dropout_p=0.0)
            replacement = partialize_and_update_signature(replacement, dropout_p=0.0)
            workaround = {}

        yield f"{name}_inference", {
            "search_fn": pattern,
            "replace_fn": replacement,
            "example_inputs": args,
            "trace_fn": inference_graph,
            "pass_dict": patterns,
            "extra_check": extra_check,
            "scalar_workaround": workaround,
        }


@functools.lru_cache(None)
def _sfdp_init():
    from .serialized_patterns.central_index import get_serialized_pattern

    for key, register_replacement_kwargs in _get_sfdp_patterns():
        search_fn_pattern = get_serialized_pattern(key)
        register_replacement(
            pattern,
            replacement,
            args,
            inference_graph,
            patterns,
            extra_check=extra_check,
            scalar_workaround=workaround,
            **register_replacement_kwargs, search_fn_pattern=search_fn_pattern
        )
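
Note (not part of the diff): the hunk above splits pattern definition from registration. `_get_sfdp_patterns()` yields `(key, kwargs)` pairs and `_sfdp_init()` consumes them, looking up a pre-serialized pattern for each key so no tracing is needed when one exists. A hedged sketch of that consumption path; `register_one` is a hypothetical helper, the imports and keyword names come from this commit's hunks.

from torch._inductor.fx_passes.serialized_patterns.central_index import (
    get_serialized_pattern,
)
from torch._inductor.pattern_matcher import register_replacement


def register_one(key, kwargs):
    # Returns None when the pattern has not been serialized yet (or when
    # config.fallback_random is set); in that case register_replacement
    # falls back to tracing the search function via gen_pattern().
    search_fn_pattern = get_serialized_pattern(key)
    register_replacement(**kwargs, search_fn_pattern=search_fn_pattern)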

@@ -0,0 +1,95 @@
# noqa: F401, E501
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py

import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

from torch._inductor.pattern_matcher import (
    Arg,
    CallFunction,
    CallFunctionVarArgs,
    CallMethod,
    CallMethodVarArgs,
    ExclusiveKeywordArg,
    Ignored,
    KeywordArg,
    ListOf,
    MultiOutputPattern,
    PatternExpr,
    RepeatedExpr,
    _TargetArgsExpr,
    _TargetExpr,
    _TargetExprVarArgs,
)
expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2)
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2)
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2)
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2)
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored())
view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2)
permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored())
bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1)
view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored())
mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2)
sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True)
mul_Tensor_1 = CallFunction(aten.mul.Tensor, div_Tensor_1, sum_dim_IntList_1)
sub_Tensor_1 = CallFunction(aten.sub.Tensor, mul_Tensor, mul_Tensor_1)
div_Tensor_2 = CallFunction(aten.div.Tensor, sub_Tensor_1, KeywordArg('inv_scale'))
view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2)
permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored())
bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2)
view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored())
permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored())
bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8)
view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored())
permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored())
permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored())
bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6)
view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored())
_sfdp_pattern_1_training = MultiOutputPattern([view_default_5,
  view_default_9,
  permute_default_4,
  view_default_11,
  None
])


expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored())
view_default = CallFunction(aten.view.default, expand_default, Ignored())
permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored())
expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored())
view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored())
bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1)
view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored())
div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2)
amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True)
sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default)
exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2)
sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True)
div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList)
expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored())
view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored())
expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored())
view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored())
bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4)
_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored())

@@ -0,0 +1,22 @@
# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
from ._sfdp_pattern_1 import (_sfdp_pattern_1_training, _sfdp_pattern_1_inference)

central_index = {
    '_sfdp_pattern_1_training': _sfdp_pattern_1_training,
    '_sfdp_pattern_1_inference': _sfdp_pattern_1_inference,
}


def get_serialized_pattern(key):
    import torch._inductor  # noqa: F401
    from torch._inductor import config
    if config.fallback_random:
        return None

    # TODO - could add more validation that the same set of decomps used when
    # tracing SDPA are also used in current context. softmax, dropout, etc
    # decomp use is stable so not an issue in practice.
    return central_index.get(key)

@@ -204,6 +204,9 @@ class Ignored(PatternExpr):
    def __repr__(self):
        return "*"

    def pretty_print(self, pp: "PatternPrettyPrinter"):
        return "Ignored()"


class KeywordArg(PatternExpr):
    """

@@ -262,11 +265,15 @@ class _TargetExpr(PatternExpr):
        self.users = users

    def fns_repr(self):
        return (
            f"[{self.fns[0].__name__}, ...]"
            if len(self.fns) > 1
            else self.fns[0].__name__
        )
        fn_name = self.fns[0].__name__
        if len(self.fns) > 1:
            return f"[{fn_name}, ...]"
        elif self.fns[0] is getattr(torch, fn_name, None):
            return f"torch.{fn_name}"
        elif isinstance(self.fns[0], torch._ops.OpOverload):
            return str(self.fns[0])
        else:
            return self.fns[0].__name__

    def __repr__(self):
        return f"{self.__class__.__name__}({self.fns_repr()})"

@@ -338,6 +345,18 @@ class _TargetArgsExpr(_TargetExpr):
        ]
        return f"{self.__class__.__name__}({', '.join(args)})"

    def pretty_print(self, pp: "PatternPrettyPrinter"):
        args = [
            self.fns_repr(),
            *(pp.pretty_print(x) for x in self.args),
            *[f"{k}={pp.pretty_print(v)}" for k, v in self.kwargs.items()],
        ]
        if self.users > 1:
            args.append(f"_users={self.users}")

        joiner_str = ", "
        return f"{self.__class__.__name__}({joiner_str.join(args)})"

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        if (
            not self._match_fns(node)

@@ -488,6 +507,13 @@ class MultiOutputPattern(PatternExpr):
    def __repr__(self):
        return f"{self.__class__.__name__}({self.outputs})"

    def pretty_print(self, pp: "PatternPrettyPrinter"):
        args = [pp.pretty_print(x) for x in self.outputs]
        joiner_str = f",\n{' '}"
        str_out = f"{self.__class__.__name__}([{joiner_str.join(args)}"
        str_out = f"{str_out}\n])"
        return str_out

    def _match(self, node: torch.fx.Node, ctx: MatchContext):
        m = ctx.match(self.outputs[0], node)
        if not m:

@@ -553,6 +579,58 @@ class RepeatedExpr(PatternExpr):
        return m


class PatternPrettyPrinter:
    """
    Serializes Patterns to executable python.
    XXX: currently only used and tested for fuse attention patterns. May not cover
    all patterns.
    """

    def __init__(self):
        self.namespace = torch.fx.graph._Namespace()
        self.memoized_objs_names: Dict[PatternExpr, str] = {}
        self.memoized_objs_pp: Dict[PatternExpr, str] = {}

    @staticmethod
    def run(obj: PatternExpr, output_name="output"):
        """
        Serializes obj to python code with obj written out to `output_name`
        """

        pp = PatternPrettyPrinter()
        out_str = obj.pretty_print(pp=pp)

        output = []
        for key in pp.memoized_objs_names:
            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")

        output.append(f"{output_name} = {out_str}")

        return "\n".join(output)

    def pretty_print(self, obj):
        if isinstance(obj, _TargetArgsExpr):
            if memoized_name := self.memoized_objs_names.get(obj):
                return memoized_name
            else:
                return self.memoize(obj)
        if hasattr(obj, "pretty_print"):
            return obj.pretty_print(self)

        return repr(obj)

    def memoize(self, obj):
        obj_str = obj.pretty_print(self)
        obj_name = obj.fns_repr()
        for prefix in ("aten.", "torch.", "prims."):
            obj_name = obj_name.replace(prefix, "")

        tmp_name = self.namespace.create_name(obj_name, None)
        self.memoized_objs_names[obj] = tmp_name
        self.memoized_objs_pp[obj] = obj_str
        return tmp_name


@dataclasses.dataclass
class PatternEntry:
    pattern: PatternExpr
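
Note (not part of the diff): the printer hoists every CallFunction sub-expression into a named local (the name comes from fns_repr() with the aten./torch./prims. prefix stripped) and reuses that name when the same object appears again, which is why shared nodes like view_default appear only once in the generated file. A toy example, invented for illustration:

import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg, PatternPrettyPrinter

aten = torch.ops.aten

# `shared` appears twice in `top`, so it is memoized once and referenced by name.
shared = CallFunction(aten.exp.default, KeywordArg("x"), _users=2)
top = CallFunction(aten.add.Tensor, shared, shared)

print(PatternPrettyPrinter.run(top, output_name="toy_pattern"))
# Expected shape of the output (exact names may differ):
#   exp_default = CallFunction(aten.exp.default, KeywordArg('x'), _users=2)
#   toy_pattern = CallFunction(aten.add.Tensor, exp_default, exp_default)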

@@ -706,6 +784,7 @@ def register_replacement(
    extra_check=_return_true,
    scalar_workaround=(),
    exclusive_arg_names=(),
    search_fn_pattern=None,
):
    """
    Create a replacement rule based on example functions that get traced

@@ -772,14 +851,17 @@
    requires_grad = [
        isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
    ]
    search_gm = trace_fn(search_fn, example_inputs)
    pattern = fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )
    if search_fn_pattern is None:
        pattern = gen_pattern(
            search_fn,
            example_inputs,
            trace_fn,
            scalar_workaround,
            exclusive_arg_names,
        )
    else:
        pattern = search_fn_pattern

    assert repr(pattern) not in _seen_patterns
    _seen_patterns.add(repr(pattern))
    pattern = ReplacementPatternEntry(

@@ -790,6 +872,21 @@
    pattern.register(pass_dict)


@functorch_config.patch(functionalize_rng_ops=False)
def gen_pattern(
    search_fn, example_inputs, trace_fn, scalar_workaround=(), exclusive_arg_names=()
) -> PatternExpr:
    argnames = [*inspect.signature(search_fn).parameters.keys()]
    search_gm = trace_fn(search_fn, example_inputs)
    return fx_to_pattern(
        search_gm,
        ignore_types=(int, float, list, torch.device, torch.dtype),
        argnames=argnames,
        scalar_workaround=scalar_workaround,
        exclusive_arg_names=exclusive_arg_names,
    )


def register_lowering_pattern(
    pattern, extra_check=_return_true, *, pass_dict, prepend=False
):

torchgen/fuse_attention_patterns/gen_attention_patterns.py (new file, 142 lines)

@@ -0,0 +1,142 @@
#!/usr/bin/env python3
import os
import shutil
from collections import defaultdict
from pathlib import Path

import torch._inductor

from torch._inductor.fx_passes.fuse_attention import _get_sfdp_patterns
from torch._inductor.pattern_matcher import (
    _TargetExpr,
    gen_pattern,
    PatternExpr,
    PatternPrettyPrinter,
)

auto_generated_msg = """# This is an auto-generated file. Please do not modify it by hand.
# To re-generate, run:
# cd ~/pytorch && python
# torchgen/fuse_attention_patterns/gen_attention_patterns.py
"""


def get_file_template() -> str:
    file_template = f"""# noqa: F401, E501
{auto_generated_msg}
import torch
import torch._inductor

aten = torch.ops.aten
prims = torch.ops.prims

"""
    pattern_matcher_imports = []
    for name in dir(torch._inductor.pattern_matcher):
        attr = getattr(torch._inductor.pattern_matcher, name)
        if isinstance(attr, type) and issubclass(attr, (PatternExpr, _TargetExpr)):
            pattern_matcher_imports.append(name)

    formatted_imports = ",\n ".join(pattern_matcher_imports)
    formatted_imports = (
        f"from torch._inductor.pattern_matcher import (\n {formatted_imports},\n)\n"
    )
    return f"{file_template}{formatted_imports}"


def get_central_index_epilogue() -> str:
    epilogue = """
def get_serialized_pattern(key):
    import torch._inductor  # noqa: F401
    from torch._inductor import config
    if config.fallback_random:
        return None

    # TODO - could add more validation that the same set of decomps used when
    # tracing SDPA are also used in current context. softmax, dropout, etc
    # decomp use is stable so not an issue in practice.
    return central_index.get(key)
"""
    return epilogue


def clean_directory(dir: Path) -> None:
    for filename in os.listdir(dir):
        file_path = os.path.join(dir, filename)
        if os.path.isfile(file_path) or os.path.islink(file_path):
            os.unlink(file_path)
        elif os.path.isdir(file_path):
            shutil.rmtree(file_path)


def serialize_functions() -> None:
    file_path = Path.cwd() / "torch" / "_inductor" / "fx_passes" / "serialized_patterns"
    if not file_path.exists():
        raise Exception(
            "Could not find serialized patterns directory. Make sure you are at Pytorch root directory"
        )

    clean_directory(file_path)

    with open(file_path / "__init__.py", "w"):
        pass

    central_index = {}
    file_to_keys = defaultdict(list)
    seen_patterns = set()

    file_template = get_file_template()
    for i, (
        key,
        kwargs,
    ) in enumerate(_get_sfdp_patterns()):
        pattern_name = kwargs["search_fn"].__name__
        gen_kwargs = {
            key: kwargs[key]
            for key in ("search_fn", "example_inputs", "trace_fn", "scalar_workaround")
        }

        # temporary to batch adding new patterns
        if i >= 2:
            continue

        from torch._functorch import config as functorch_config

        with functorch_config.patch(functionalize_rng_ops=False):
            pattern = gen_pattern(**gen_kwargs)

        serialized_pattern = PatternPrettyPrinter.run(pattern, output_name=key)
        if pattern_name not in seen_patterns:
            write_mode = "w"
            seen_patterns.add(pattern_name)
        else:
            write_mode = "a"

        with open(file_path / f"{pattern_name}.py", write_mode) as f:
            if write_mode == "w":
                f.write(file_template)
            else:
                f.write("\n\n")
            f.write(serialized_pattern)
            f.write("\n")

        central_index[f"{key}"] = f"{pattern_name}.py"

        file_to_keys[pattern_name].append(f"{key}")

    with open(file_path / "central_index.py", "w") as f:
        f.write(auto_generated_msg)
        for pattern_name, keys in file_to_keys.items():
            f.write(f"from .{pattern_name} import ({', '.join(keys)})\n")

        f.write("\ncentral_index = {\n")
        for k in central_index.keys():
            f.write(f" '{k}': {k},\n")
        f.write("}\n\n")

        f.write(get_central_index_epilogue())


if __name__ == "__main__":
    with torch._subclasses.FakeTensorMode():
        serialize_functions()
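
Note (not part of the diff): per the generator's own header comment, it is meant to be run from the repository root against a built PyTorch. A hedged sketch of invoking it, with the path taken from the script above:

import subprocess

# Rewrites torch/_inductor/fx_passes/serialized_patterns/ (including
# central_index.py); run from the PyTorch repo root after building.
subprocess.run(
    ["python", "torchgen/fuse_attention_patterns/gen_attention_patterns.py"],
    check=True,
)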