pytorch/torchgen/fuse/gen_patterns.py
Aaron Orenstein 4044e93a51 Add mm_pattern and bmm_pattern to serialized_patterns (#121313)
Make it easier to serialize patterns by adding `pattern_matcher.gen_register_replacement()`, which works like `pattern_matcher.register_replacement()` but also requires the replacement to be precompiled.

To precompile the patterns and save them to disk, run:
```
torchgen/fuse_attention_patterns/gen_attention_patterns.py
```
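
One way to picture what `gen_register_replacement()` changes, and why the precompile step above exists, is a registry with two modes. Assume that when `PYTORCH_GEN_PATTERNS` is set, each pattern is traced and written to disk, and that otherwise only previously written patterns are loaded. The sketch below is a toy model under that assumption. `gen_register`, `trace_pattern`, `demo`, the JSON file format and the `cache_dir` argument are all hypothetical and do not mirror Inductor's real implementation. Only the environment variable name comes from the script.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Callable, Dict

# Hypothetical model of "register only precompiled patterns". None of these
# helpers are torch._inductor APIs; only the environment variable name is
# taken from gen_patterns.py.
GEN_ENV_VAR = "PYTORCH_GEN_PATTERNS"


def trace_pattern(fn: Callable) -> dict:
    # Stand-in for the expensive tracing step: record a tiny summary of fn.
    return {"name": fn.__name__, "nargs": fn.__code__.co_argcount}


def gen_register(fn: Callable, registry: Dict[str, dict], cache_dir: Path) -> None:
    path = cache_dir / f"{fn.__name__}.json"
    if os.environ.get(GEN_ENV_VAR) == "1":
        # Generation mode: trace now and save the result for later runs.
        pattern = trace_pattern(fn)
        path.write_text(json.dumps(pattern))
    else:
        # Normal mode: refuse to trace; the precompiled file must exist.
        if not path.is_file():
            raise FileNotFoundError(f"no precompiled pattern at {path}")
        pattern = json.loads(path.read_text())
    registry[fn.__name__] = pattern


def mm_pattern(a, b):
    return a @ b


def demo(cache_dir: Path) -> Dict[str, dict]:
    registry: Dict[str, dict] = {}
    os.environ[GEN_ENV_VAR] = "1"
    try:
        gen_register(mm_pattern, registry, cache_dir)  # traces, writes mm_pattern.json
    finally:
        os.environ.pop(GEN_ENV_VAR, None)
    registry.clear()
    gen_register(mm_pattern, registry, cache_dir)  # loads the saved file
    return registry


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(demo(Path(d)))  # {'mm_pattern': {'name': 'mm_pattern', 'nargs': 2}}
```

The design trade-off this models is that normal runs skip tracing entirely. The cost is that the saved files must be regenerated whenever a pattern changes, which is the job of the script below.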

- Updated the sfdp patterns to use `gen_register_replacement`.
- Added serialized patterns for mm_pattern and bmm_pattern (the 'misc' patterns don't serialize cleanly, so they can't be added).
- Updated the tests to check that the round-tripped patterns match, not just that they serialize the same way.
- Checking that the patterns round-trip correctly revealed that the `users` field wasn't being serialized properly.
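
The difference between the two kinds of test can be shown with a toy serializer. In this sketch, a serializer that drops the `users` field still reproduces its own saved output exactly, so a "serialized the same way" comparison passes. Deserializing and comparing against the original graph is what exposes the lost field. All names here (`Node`, `serialize`, `deserialize`, `matches_saved`, `round_trips`, `example_graph`) are hypothetical, and JSON is used only for brevity; this is not Inductor's serialization format.

```python
import json
from dataclasses import dataclass, field
from typing import List


# Toy graph node; `users` lists the nodes that consume this node's output,
# loosely echoing the users of an FX node. Not Inductor's real format.
@dataclass
class Node:
    name: str
    op: str
    args: List[str] = field(default_factory=list)
    users: List[str] = field(default_factory=list)


def serialize(nodes: List[Node], keep_users: bool = True) -> str:
    # keep_users=False mimics a serializer that silently drops `users`.
    records = []
    for n in nodes:
        rec = {"name": n.name, "op": n.op, "args": n.args}
        if keep_users:
            rec["users"] = n.users
        records.append(rec)
    return json.dumps(records, sort_keys=True)


def deserialize(text: str) -> List[Node]:
    return [
        Node(r["name"], r["op"], r["args"], r.get("users", []))
        for r in json.loads(text)
    ]


def matches_saved(nodes: List[Node], saved: str, keep_users: bool) -> bool:
    # Weak check: a fresh serialization equals a previously saved copy.
    return serialize(nodes, keep_users) == saved


def round_trips(nodes: List[Node], keep_users: bool) -> bool:
    # Strong check: deserialize and compare every field with the original.
    return deserialize(serialize(nodes, keep_users)) == nodes


def example_graph() -> List[Node]:
    # x, y -> mm -> output
    return [
        Node("x", "placeholder", users=["mm"]),
        Node("y", "placeholder", users=["mm"]),
        Node("mm", "call_function", ["x", "y"], ["output"]),
        Node("output", "output", ["mm"]),
    ]


if __name__ == "__main__":
    g = example_graph()
    saved = serialize(g, keep_users=False)  # produced by the buggy serializer
    print(matches_saved(g, saved, keep_users=False))  # True: the bug goes unnoticed
    print(round_trips(g, keep_users=False))  # False: users were lost
    print(round_trips(g, keep_users=True))  # True
```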

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121313
Approved by: https://github.com/eellison
2024-04-09 19:42:19 +00:00

#!/usr/bin/env python3
import os

from torch._inductor import pattern_matcher
from torch._inductor.fx_passes import joint_graph


if __name__ == "__main__":
    # Start by deleting all the existing patterns.
    for file in os.listdir(pattern_matcher.SERIALIZED_PATTERN_PATH):
        if file in ("__init__.py", "__pycache__"):
            continue
        file = pattern_matcher.SERIALIZED_PATTERN_PATH / file
        if file.is_file():
            file.unlink()

    # Now have joint_graph load all known patterns and tell the pattern matcher
    # to serialize the patterns as it goes.
    os.environ["PYTORCH_GEN_PATTERNS"] = "1"
    joint_graph.lazy_init()
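
To exercise the deletion step on its own without touching the real `SERIALIZED_PATTERN_PATH`, the loop can be pulled into a function and pointed at a temporary directory. This is a sketch only. `clear_serialized_patterns` is a hypothetical helper, not part of the script, it sorts the directory listing so its result is deterministic, and the file names in the demo are illustrative. The `__pycache__` entry is skipped by name, and any other subdirectory also survives because only regular files are unlinked.

```python
import os
import tempfile
from pathlib import Path
from typing import List

# Entries the script never deletes (taken from gen_patterns.py).
KEEP = ("__init__.py", "__pycache__")


def clear_serialized_patterns(pattern_dir: Path) -> List[str]:
    """Hypothetical helper mirroring the deletion loop in gen_patterns.py.

    Returns the names of the files it removed, in sorted order.
    """
    removed = []
    for name in sorted(os.listdir(pattern_dir)):
        if name in KEEP:
            continue
        path = pattern_dir / name
        if path.is_file():  # directories are left untouched
            path.unlink()
            removed.append(name)
    return removed


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        root = Path(d)
        (root / "__init__.py").write_text("")
        (root / "mm_pattern.py").write_text("# generated\n")
        (root / "bmm_pattern.py").write_text("# generated\n")
        (root / "__pycache__").mkdir()
        print(clear_serialized_patterns(root))  # ['bmm_pattern.py', 'mm_pattern.py']
        print(sorted(os.listdir(root)))  # ['__init__.py', '__pycache__']
```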