diff --git a/test/inductor/test_split_cat_fx_aten_passes.py b/test/inductor/test_split_cat_fx_aten_passes.py
index 317ba11830b..1d9826c7644 100644
--- a/test/inductor/test_split_cat_fx_aten_passes.py
+++ b/test/inductor/test_split_cat_fx_aten_passes.py
@@ -48,6 +48,23 @@ class TestSplitCat(torch.nn.Module):
         return cat_1
 
 
+class TestSelectCat(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        select = torch.ops.aten.select.int(x, 1, 0)
+        select_1 = torch.ops.aten.select.int(x, 1, 1)
+        select_2 = torch.ops.aten.select.int(x, 1, 2)
+        select_3 = torch.ops.aten.select.int(x, 1, 3)
+        select_4 = torch.ops.aten.select.int(x, 1, 4)
+        select_5 = torch.ops.aten.select.int(x, 1, 5)
+        cat = torch.ops.aten.cat.default(
+            [select, select_1, select_2, select_3, select_4, select_5], 1
+        )
+        return cat
+
+
 class TestSplitCatAten(TestCase):
     def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
         if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@@ -101,6 +118,30 @@ class TestSplitCatAten(TestCase):
         self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
         counters.clear()
 
+    @requires_cuda
+    @torch._inductor.config.patch(
+        pre_grad_fusion_options={},
+        post_grad_fusion_options={
+            "normalization_aten_pass": {},
+            "select_cat_aten_pass": {},
+        },
+    )
+    def test_select_cat_post_grad(self):
+        counters.clear()
+        inputs = [
+            torch.randn(1024, 6, 128, device=torch.device(device=GPU_TYPE)),
+        ]
+        module = TestSelectCat()
+        traced = torch.compile(module)
+        ref = module(*inputs)
+        res = traced(*inputs)
+        self.compare_pred(module, traced, inputs)
+        self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1)
+        self.assertEqual(counters["inductor"]["select_cat_aten_pass"], 1)
+        self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
+        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
+        counters.clear()
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index bfef16f4403..db6e188e233 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -70,6 +70,7 @@ post_grad_pass_names = [
     "shape_padding_multiplier",
     "pad_aten_mm_pass",
     "split_cat_aten_pass",
+    "select_cat_aten_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
@@ -1711,8 +1712,8 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
     # check split node and cat node has same dim
     if split_dim != cat_dim:
         return
-    # check the split node has consecutive indices
-    indices = [getitem.args[1] for getitem in getitem_nodes]
+    # check the cat node has consecutive indices
+    indices = [arg.args[1] for arg in cat_node.args[0]]  # type: ignore[union-attr]
     if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs):  # type: ignore[arg-type]
         return
     # replace the users of the cat node to be the input of the split node
@@ -1728,6 +1729,58 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
     counters["inductor"]["split_cat_aten_pass"] += 1
 
 
+@register_graph_pattern(
+    CallFunction(
+        torch.ops.aten.cat.default,
+        ListOf(
+            CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE),
+            partial=True,
+        ),
+        dim=Ignored(),
+        _users=MULTIPLE,
+    ),
+    pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"),
+)
+def merge_select_cat_aten(match: Match, *args, **kwargs):
+    graph = match.graph
+    node = match.nodes[0]
+    node_input = get_arg_value(node, 0, "tensors")
+    # get the select nodes from the node
+    select_nodes = list(node_input.users.keys())
+    cat_node = next(iter(select_nodes[0].users.keys()))
+    cat_dim = get_arg_value(cat_node, 1, "dim")
+    cat_inputs = get_arg_value(cat_node, 0, "tensors")
+    # check all select nodes have the same slice dim
+    if not all(
+        select_node.args[1] == select_nodes[0].args[1] for select_node in select_nodes
+    ):
+        return
+    # we only consider the case where the select slice dim and the cat dim are the same
+    if select_nodes[0].args[1] != cat_dim:
+        return
+    if not is_node_meta_valid(cat_node):
+        return
+    # check the cat node has consecutive indices
+    indices = [select.args[2] for select in cat_node.args[0]]  # type: ignore[union-attr]
+    if not is_sorted_and_consecutive(indices) and len(select_nodes) != len(cat_inputs):  # type: ignore[arg-type]
+        return
+    # reshape the node input to be the same shape as the cat node
+    with graph.inserting_before(node):
+        view_node = graph.call_function(
+            torch.ops.aten.view.default,
+            args=(node_input, cat_node.meta["val"].shape),
+        )
+    # replace all uses of the cat node with the view node
+    cat_node.replace_all_uses_with(view_node)
+    view_node.meta.update(cat_node.meta)
+    # remove the cat node
+    graph.erase_node(cat_node)
+    for select_node in select_nodes:
+        if len(select_node.users) == 0:
+            graph.erase_node(select_node)
+    counters["inductor"]["select_cat_aten_pass"] += 1
+
+
 @register_graph_pattern(
     CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
     pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),