Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-15 21:00:47 +00:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/73345

For complex patterns we need to identify which node is the root, so that we can eliminate all other nodes and preserve only the root. For example, in (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d)) we can preserve torch.nn.Conv2d as the root node and remove the other nodes.

Previously we assumed that the root node of a pattern is the "last node" of the pattern, computed by:

```python
from torch.fx import Node

def default_root_node_getter(node_pattern):
    # Descend into the last element of the (possibly nested) pattern
    # until it is an fx Node; that node is treated as the root.
    while not isinstance(node_pattern[-1], Node):
        node_pattern = node_pattern[-1]
    return node_pattern[-1]
```

This PR lets users define their own root_node_getter through configuration, so that a root node can also be defined for patterns like (torch.add, (torch.nn.ReLU, torch.nn.Conv2d), MatchAllNode).

Test Plan: `python test/test_quantize_fx.py TestFuseFx.test_root_node_getter`

Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D34442193

fbshipit-source-id: 2f6da69a5b6527b49710ae32820e8e2915d9af37

(cherry picked from commit 8b49bf0d7d53cdcf2c9f40f8e25bc843e8814026)
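The following is a minimal sketch of why the "last node" rule breaks for the second pattern shape and what a custom getter could look like. It assumes that a matched node pattern mirrors the pattern tuple, with a graph node in place of each op. To keep it runnable without PyTorch, `Node` here is a hypothetical stand-in for `torch.fx.Node`, and `conv_root_node_getter` is a hypothetical name, not an API from the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for torch.fx.Node: just a named graph node."""
    name: str


def default_root_node_getter(node_pattern):
    # Follow the last element until it is a Node; that Node is the root.
    while not isinstance(node_pattern[-1], Node):
        node_pattern = node_pattern[-1]
    return node_pattern[-1]


def conv_root_node_getter(node_pattern):
    # Hypothetical custom getter for a match of
    # (torch.add, (torch.nn.ReLU, torch.nn.Conv2d), MatchAllNode):
    # the conv sits inside the middle element, not at the end.
    _add, (_relu, conv), _other = node_pattern
    return conv


add, relu, conv, other = Node("add"), Node("relu"), Node("conv"), Node("other")

# Shape 1: (torch.add, MatchAllNode, (torch.nn.ReLU, torch.nn.Conv2d))
trailing = (add, other, (relu, conv))
# Shape 2: (torch.add, (torch.nn.ReLU, torch.nn.Conv2d), MatchAllNode)
middle = (add, (relu, conv), other)

print(default_root_node_getter(trailing).name)  # conv
print(default_root_node_getter(middle).name)    # other, not the conv
print(conv_root_node_getter(middle).name)       # conv
```

For the first shape the default rule descends into the trailing (ReLU, Conv2d) tuple and finds the conv. For the second shape the last element is already a node (the input matched by MatchAllNode), so the default rule stops there and a user-supplied getter is needed. The commit message does not show how such a getter is registered in the configuration, so that step is omitted here.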
| Name |
|---|
| ao_migration |
| bc |
| core |
| dbr |
| eager |
| fx |
| jit |
| serialized |
| \_\_init\_\_.py |