[Customized Optimus] Add select cat aten pass (#145918)

Summary: This is a follow-up to D68695717: we further reduce the number of cat kernels in the backward by adding a new pass at the aten level.

Test Plan:
# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_select_cat_post_grad
```

Buck UI: https://www.internalfb.com/buck2/6943087f-91be-4dbd-9693-df0a11a50b73
Test UI: https://www.internalfb.com/intern/testinfra/testrun/11821949087998233
Network: Up: 101KiB  Down: 132KiB  (reSessionID-60e898af-f366-4247-a9f7-d8d7cd129fe0)
Analyzing targets. Remaining      0/78148
Executing actions. Remaining      0/476147
Command: test.     Finished 2 local
Tests finished: Pass 3. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### how to add the config

```
        post_grad_fusion_options: {
          "normalization_aten_pass": {},
          "split_cat_aten_pass": {},
          "select_cat_aten_pass": {},
        }
```

{F1974778773}

baseline:

aps-recgpt_ranking_1115_pt2_optimus-e52c1f277e

proposal:

aps-recgpt_ranking_1115_pt2_optimus-1b0047ee0e

Differential Revision: D68803384

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145918
Approved by: https://github.com/Yuzhen11
This commit is contained in:
Menglu Yu 2025-01-31 02:35:10 +00:00 committed by PyTorch MergeBot
parent 08d88127fe
commit e01c898e51
2 changed files with 96 additions and 2 deletions

View file

@ -48,6 +48,23 @@ class TestSplitCat(torch.nn.Module):
return cat_1
class TestSelectCat(torch.nn.Module):
    """Fixture emitting six ``aten.select.int`` ops on dim 1 feeding a single
    ``aten.cat.default`` — the exact graph shape targeted by the
    ``select_cat_aten_pass`` post-grad fusion."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        # The loop unrolls at trace time into six distinct select nodes,
        # identical to writing the six calls out by hand.
        slices = [torch.ops.aten.select.int(x, 1, idx) for idx in range(6)]
        return torch.ops.aten.cat.default(slices, 1)
class TestSplitCatAten(TestCase):
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@ -101,6 +118,30 @@ class TestSplitCatAten(TestCase):
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
counters.clear()
@requires_cuda
@torch._inductor.config.patch(
    pre_grad_fusion_options={},
    post_grad_fusion_options={
        "normalization_aten_pass": {},
        "select_cat_aten_pass": {},
    },
)
def test_select_cat_post_grad(self):
    """Compile TestSelectCat and verify that the normalization and
    select_cat aten post-grad passes each fire exactly once, and that the
    compiled module matches eager numerics and parameters."""
    counters.clear()
    sample = [
        torch.randn(1024, 6, 128, device=torch.device(device=GPU_TYPE)),
    ]
    eager_module = TestSelectCat()
    compiled_module = torch.compile(eager_module)
    expected = eager_module(*sample)
    actual = compiled_module(*sample)
    self.compare_pred(eager_module, compiled_module, sample)
    # Each pass should have matched exactly one pattern in this graph.
    self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1)
    self.assertEqual(counters["inductor"]["select_cat_aten_pass"], 1)
    self.assertEqual(expected, actual, rtol=1e-8, atol=1e-8)
    self.compare_parameters(eager_module, compiled_module, rtol=1e-8, atol=1e-8)
    counters.clear()
# Standard inductor test entry point when the file is executed directly.
if __name__ == "__main__":
    run_tests()

View file

@ -70,6 +70,7 @@ post_grad_pass_names = [
"shape_padding_multiplier",
"pad_aten_mm_pass",
"split_cat_aten_pass",
"select_cat_aten_pass",
]
for pass_name in pre_grad_pass_names:
@ -1711,8 +1712,8 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
# check split node and cat node has same dim
if split_dim != cat_dim:
return
# check the split node has consecutive indices
indices = [getitem.args[1] for getitem in getitem_nodes]
# check the cat node has consecutive indices
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
return
# replace the users of the cat node to be the input of the split node
@ -1728,6 +1729,58 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
counters["inductor"]["split_cat_aten_pass"] += 1
@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        ListOf(
            CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE),
            partial=True,
        ),
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"),
)
def merge_select_cat_aten(match: Match, *args, **kwargs):
    """Collapse ``cat([select(x, d, 0), ..., select(x, d, k)], d)`` into a
    single ``aten.view`` of ``x``.

    When every input of the cat is a select of the same source tensor along
    the same dim as the cat, and the select indices are consecutive and cover
    all selects of that source, the cat merely reassembles the source tensor;
    a view with the cat's output shape is equivalent and removes the cat
    kernel from the backward graph.
    """
    graph = match.graph
    node = match.nodes[0]
    node_input = get_arg_value(node, 0, "tensors")
    # All users of the selects' source tensor. Assumes every user is one of
    # the select nodes feeding the cat (``partial=True`` admits graphs where
    # that may not hold) -- TODO(review): confirm for matched graphs.
    select_nodes = list(node_input.users.keys())
    cat_node = next(iter(select_nodes[0].users.keys()))
    cat_dim = get_arg_value(cat_node, 1, "dim")
    cat_inputs = get_arg_value(cat_node, 0, "tensors")
    # check all select nodes slice the same dim
    if not all(
        select_node.args[1] == select_nodes[0].args[1] for select_node in select_nodes
    ):
        return
    # We only consider the case where the select slice dim and the cat dim match
    if select_nodes[0].args[1] != cat_dim:
        return
    if not is_node_meta_valid(cat_node):
        return
    # check the cat node has consecutive indices
    indices = [select.args[2] for select in cat_node.args[0]]  # type: ignore[union-attr]
    # Bail unless BOTH hold: consecutive indices AND the cat consumes every
    # select of the source. The previous ``and`` only bailed when both checks
    # failed, so a permuted-but-complete (or consecutive-but-partial) cat
    # could be rewritten into a view with different semantics.
    # NOTE(review): consecutive indices not starting at 0 would view only a
    # slice of the source; presumably the numel mismatch makes aten.view fail
    # loudly rather than silently -- confirm is_sorted_and_consecutive anchors
    # at 0.
    if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs):  # type: ignore[arg-type]
        return
    # reshape the node input to be the same shape as the cat node
    with graph.inserting_before(node):
        view_node = graph.call_function(
            torch.ops.aten.view.default,
            args=(node_input, cat_node.meta["val"].shape),
        )
    # redirect all consumers of the cat node to the new view node
    cat_node.replace_all_uses_with(view_node)
    view_node.meta.update(cat_node.meta)
    # remove the cat node
    graph.erase_node(cat_node)
    # drop select nodes that became dead once the cat was removed
    for select_node in select_nodes:
        if len(select_node.users) == 0:
            graph.erase_node(select_node)
    counters["inductor"]["select_cat_aten_pass"] += 1
@register_graph_pattern(
CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),