mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Customized Optimus] Add select cat aten pass (#145918)
Summary: This is a follow up work of D68695717, where we can further reduce the number of cat kernels in the backward by designing new aten pass in the aten level. Test Plan: # unit test ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_select_cat_post_grad ``` Buck UI: https://www.internalfb.com/buck2/6943087f-91be-4dbd-9693-df0a11a50b73 Test UI: https://www.internalfb.com/intern/testinfra/testrun/11821949087998233 Network: Up: 101KiB Down: 132KiB (reSessionID-60e898af-f366-4247-a9f7-d8d7cd129fe0) Analyzing targets. Remaining 0/78148 Executing actions. Remaining 0/476147 Command: test. Finished 2 local Tests finished: Pass 3. Fail 0. Fatal 0. Skip 0. Build failure 0 # E2E ### how to add the config ``` post_grad_fusion_options: { "normalization_aten_pass": {}, "split_cat_aten_pass": {}, "select_cat_aten_pass": {}, } ``` {F1974778773} baseline: aps-recgpt_ranking_1115_pt2_optimus-e52c1f277e proposal aps-recgpt_ranking_1115_pt2_optimus-1b0047ee0e Differential Revision: D68803384 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145918 Approved by: https://github.com/Yuzhen11
This commit is contained in:
parent
08d88127fe
commit
e01c898e51
2 changed files with 96 additions and 2 deletions
|
|
@ -48,6 +48,23 @@ class TestSplitCat(torch.nn.Module):
|
|||
return cat_1
|
||||
|
||||
|
||||
class TestSelectCat(torch.nn.Module):
    """Toy module whose trace is six ``aten.select.int`` ops on dim 1
    followed by an ``aten.cat.default`` on the same dim — exactly the
    shape the ``select_cat_aten_pass`` collapses into a single view.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        # Slice out indices 0..5 along dim 1, then stitch them back
        # together along that same dim.
        pieces = [torch.ops.aten.select.int(x, 1, idx) for idx in range(6)]
        return torch.ops.aten.cat.default(pieces, 1)
|
||||
|
||||
|
||||
class TestSplitCatAten(TestCase):
|
||||
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
|
||||
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
|
||||
|
|
@ -101,6 +118,30 @@ class TestSplitCatAten(TestCase):
|
|||
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||
counters.clear()
|
||||
|
||||
@requires_cuda
|
||||
@torch._inductor.config.patch(
|
||||
pre_grad_fusion_options={},
|
||||
post_grad_fusion_options={
|
||||
"normalization_aten_pass": {},
|
||||
"select_cat_aten_pass": {},
|
||||
},
|
||||
)
|
||||
def test_select_cat_post_grad(self):
|
||||
counters.clear()
|
||||
inputs = [
|
||||
torch.randn(1024, 6, 128, device=torch.device(device=GPU_TYPE)),
|
||||
]
|
||||
module = TestSelectCat()
|
||||
traced = torch.compile(module)
|
||||
ref = module(*inputs)
|
||||
res = traced(*inputs)
|
||||
self.compare_pred(module, traced, inputs)
|
||||
self.assertEqual(counters["inductor"]["normalization_aten_pass"], 1)
|
||||
self.assertEqual(counters["inductor"]["select_cat_aten_pass"], 1)
|
||||
self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
|
||||
self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
|
||||
counters.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ post_grad_pass_names = [
|
|||
"shape_padding_multiplier",
|
||||
"pad_aten_mm_pass",
|
||||
"split_cat_aten_pass",
|
||||
"select_cat_aten_pass",
|
||||
]
|
||||
|
||||
for pass_name in pre_grad_pass_names:
|
||||
|
|
@ -1711,8 +1712,8 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
|
|||
# check split node and cat node has same dim
|
||||
if split_dim != cat_dim:
|
||||
return
|
||||
# check the split node has consecutive indices
|
||||
indices = [getitem.args[1] for getitem in getitem_nodes]
|
||||
# check the cat node has consecutive indices
|
||||
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
|
||||
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
|
||||
return
|
||||
# replace the users of the cat node to be the input of the split node
|
||||
|
|
@ -1728,6 +1729,58 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
|
|||
counters["inductor"]["split_cat_aten_pass"] += 1
|
||||
|
||||
|
||||
@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        ListOf(
            CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE),
            partial=True,
        ),
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"),
)
def merge_select_cat_aten(match: Match, *args, **kwargs):
    """Collapse ``cat([select(x, d, 0), ..., select(x, d, k)], d)`` into a
    single ``aten.view`` of ``x``.

    When every input of the cat is a ``select.int`` taken from the same
    tensor, on the same dim as the cat, with consecutive indices covering
    all of them, the cat merely re-flattens the selected dim — so it can be
    replaced by one view to the cat's output shape.
    """
    graph = match.graph
    node = match.nodes[0]
    # ``node`` is the matched select; arg 0 is the tensor being selected from.
    node_input = get_arg_value(node, 0, "tensors")
    # All users of that tensor are assumed to be the sibling select nodes.
    # NOTE(review): if node_input has non-select users this list would
    # include them — presumably the pattern guarantees it doesn't; confirm.
    select_nodes = list(node_input.users.keys())
    cat_node = next(iter(select_nodes[0].users.keys()))
    cat_dim = get_arg_value(cat_node, 1, "dim")
    cat_inputs = get_arg_value(cat_node, 0, "tensors")
    # Bail unless every select slices along the same dim.
    if not all(
        select_node.args[1] == select_nodes[0].args[1] for select_node in select_nodes
    ):
        return
    # We only handle the case where the select slice dim equals the cat dim.
    if select_nodes[0].args[1] != cat_dim:
        return
    # Need the cat's FakeTensor meta below to know the output shape.
    if not is_node_meta_valid(cat_node):
        return
    # Check the cat node consumes consecutive select indices.
    indices = [select.args[2] for select in cat_node.args[0]]  # type: ignore[union-attr]
    # NOTE(review): with ``and`` this only bails when indices are BOTH
    # non-consecutive AND the select/cat counts differ; ``or`` looks
    # intended. Same pattern exists in merge_split_cat_aten — confirm.
    if not is_sorted_and_consecutive(indices) and len(select_nodes) != len(cat_inputs):  # type: ignore[arg-type]
        return
    # Reshape the selects' input tensor directly to the cat's output shape.
    with graph.inserting_before(node):
        view_node = graph.call_function(
            torch.ops.aten.view.default,
            args=(node_input, cat_node.meta["val"].shape),
        )
    # Redirect all consumers of the cat to the new view node.
    cat_node.replace_all_uses_with(view_node)
    view_node.meta.update(cat_node.meta)
    # Remove the cat node, then any select nodes left without users.
    graph.erase_node(cat_node)
    for select_node in select_nodes:
        if len(select_node.users) == 0:
            graph.erase_node(select_node)
    counters["inductor"]["select_cat_aten_pass"] += 1
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
|
||||
pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
|
||||
|
|
|
|||
Loading…
Reference in a new issue