mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Customized Optimus][Inductor] Add split cat pattern in aten level (#145721)
Summary: Thanks Microve for discovering that recGPT has some repeated similar kernels that might be optimized through optimus. After investigation, I designed a pattern in the aten level to remove such excessive kernels. trace: https://fburl.com/perfdoctor/82fauil7 tlparse: https://fburl.com/98q6tadx Test Plan: # unit test ``` buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:split_cat_fx_aten_passes -- test_split_cat_post_grad ``` Buck UI: https://www.internalfb.com/buck2/e8458d63-b8ca-498b-a731-77a83fb4d1cb Test UI: https://www.internalfb.com/intern/testinfra/testrun/16325548715106567 Network: Up: 341KiB Down: 359KiB (reSessionID-7d3de666-7fc1-4988-8d11-d75ba958016d) Executing actions. Remaining 0/3 Command: test. Finished 2 local Time elapsed: 3:04.8s Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 # local run ``` buck2 run @//mode/opt aps_models/ads/recgpt_exp:recgpt_launcher -- mode=local_recgpt_ranking_30x_v0_unified_seq_1115 ``` https://www.internalfb.com/mlhub/pipeline/1630903954173593 # E2E ``` buck2 run @//mode/opt aps_models/ads/recgpt_exp:recgpt_launcher -- mode=mast_recgpt_ranking_30x_v0_unified_seq_1115 launcher.oncall=ads_model_platform launcher.data_project=ai_large_scale launcher.fbl_entitlement=ads_global_tc_training_efficiency launcher.tags=[ads_ranking_taxonomy_mc_qps_optimization] launcher.hardware=SMC_T20 launcher.job_name=recgpt_ranking_1115_pt2_with_optimus data_loader.dataset.table_ds=[2024-12-13,2024-12-14,2024-12-15,2024-12-16,2024-12-17,2024-12-18] ``` ### how to add the config Add the following patterns to the dynamo config ``` post_grad_fusion_options: { "normalization_aten_pass": {}, "split_cat_aten_pass": {}, } ``` {F1974700331} baseline: aps-recgpt_ranking_1115_pt2_5-8cb4905c7d {F1974700216} proposal: Differential Revision: D68695717 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145721 Approved by: https://github.com/Yuzhen11
This commit is contained in:
parent
331f49057d
commit
29521256e1
2 changed files with 199 additions and 4 deletions
106
test/inductor/test_split_cat_fx_aten_passes.py
Normal file
106
test/inductor/test_split_cat_fx_aten_passes.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import torch
|
||||
import torch._inductor
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
|
||||
try:
|
||||
# importing this will register fbgemm lowerings for inductor
|
||||
import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401
|
||||
|
||||
has_fbgemm = True
|
||||
except Exception:
|
||||
has_fbgemm = False
|
||||
|
||||
|
||||
class TestSplitCat(torch.nn.Module):
    """Module whose forward builds a cat -> split -> cat chain.

    The second cat re-concatenates every split section, in order and along
    the same dim, so the ``split_cat_aten_pass`` should collapse the whole
    chain back to the first cat's output.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # cat(x, y) -> split into 32-wide sections -> cat all 8 sections back,
        # all along dim 1 (inputs are expected to be 128 columns wide each).
        merged = torch.ops.aten.cat.default([x, y], 1)
        sections = torch.ops.aten.split.Tensor(merged, 32, 1)
        pieces = [sections[idx] for idx in range(8)]
        return torch.ops.aten.cat.default(pieces, 1)
|
||||
|
||||
|
||||
class TestSplitCatAten(TestCase):
    """Tests for the aten-level split/cat post-grad fusion passes.

    Compiles :class:`TestSplitCat` with the ``normalization_aten_pass`` and
    ``split_cat_aten_pass`` enabled and checks both that the passes fire
    (via counters) and that numerics are unchanged.
    """

    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        """Return True iff each eager tensor allclose-matches its traced
        counterpart (traced keys carry an ``_orig_mod.`` prefix)."""
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        matched = True
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            matched = torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol
            )
            if not matched:
                break
        return matched

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        """Assert eager and compiled forward outputs agree on *input*."""
        self.assertEqual(module(*input), traced(*input), rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        """Assert eager and compiled parameter tensors agree."""
        self.assertTrue(
            self.compare_dict_tensors(
                dict(module.named_parameters()),
                dict(traced.named_parameters()),
                rtol,
                atol,
            )
        )

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        """Assert eager and compiled parameter gradients agree."""
        ref_grad = dict((key, param.grad) for key, param in module.named_parameters())
        res_grad = dict((key, param.grad) for key, param in traced.named_parameters())
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_cuda
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "normalization_aten_pass": {},
            "split_cat_aten_pass": {},
        },
    )
    def test_split_cat_post_grad(self):
        counters.clear()
        sample_args = [
            torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)),
            torch.randn(1024, 128, device=torch.device(device=GPU_TYPE)),
        ]
        module = TestSplitCat()
        traced = torch.compile(module)
        ref = module(*sample_args)
        res = traced(*sample_args)
        self.compare_pred(module, traced, sample_args)
        # 1 split + 2 cats are normalized; the split/cat chain merges once.
        self.assertEqual(counters["inductor"]["normalization_aten_pass"], 3)
        self.assertEqual(counters["inductor"]["split_cat_aten_pass"], 1)
        self.assertEqual(ref, res, rtol=1e-8, atol=1e-8)
        self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8)
        counters.clear()
|
||||
|
||||
|
||||
# Run under the inductor test harness when executed directly.
if __name__ == "__main__":
    run_tests()
|
||||
|
|
@ -69,6 +69,7 @@ post_grad_pass_names = [
|
|||
"unbind_stack_aten_pass",
|
||||
"shape_padding_multiplier",
|
||||
"pad_aten_mm_pass",
|
||||
"split_cat_aten_pass",
|
||||
]
|
||||
|
||||
for pass_name in pre_grad_pass_names:
|
||||
|
|
@ -1635,6 +1636,98 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
|
|||
counters["inductor"]["mutate_cat_pass"] += 1
|
||||
|
||||
|
||||
# Pattern for the list-of-tensors argument of an aten cat whose elements are
# ``operator.getitem(split, idx)`` projections of an aten ``split.Tensor``
# node.  Used by ``merge_split_cat_aten`` below.
# NOTE(review): ``partial=True`` presumably allows the list to match even if
# not every element fits the getitem pattern — confirm against
# ``pattern_matcher.ListOf``; the consuming pass must then re-validate that
# every cat input really is a getitem of the same split.
getitem_split_aten = ListOf(
    CallFunction(
        operator.getitem,
        CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE),
        Ignored(),
        _users=MULTIPLE,
    ),
    partial=True,
)
|
||||
|
||||
|
||||
@register_graph_pattern(
    CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE),
    pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
)
def normalize_split_default_aten(match: Match, *args, **kwargs):
    """Rewrite an aten ``split.Tensor`` node into canonical form.

    Canonical form is ``split(input, split_size, dim=<non-negative dim>)``
    with the dim passed as a keyword, so downstream split/cat passes can
    match nodes positionally regardless of how the original call spelled
    its arguments.  Bails out (no-op) on missing args/meta or symbolic
    section sizes.
    """
    split_node = match.nodes[0]
    graph = match.graph
    split_input, split_size, split_dim = _get_split_args_default(split_node)
    if split_input is None or split_dim is None or split_size is None:
        log.debug("couldn't find split args")
        return
    if not is_node_meta_valid(split_node):
        log.debug("val absent for node: %s", split_node)
        return
    assert isinstance(split_node.meta["val"], (list, tuple))
    # Concrete size of each output section along the split dim, taken from
    # the faketensor metadata.
    split_sections = [t.size()[split_dim] for t in split_node.meta["val"]]
    if any(isinstance(section, torch.SymInt) for section in split_sections):
        # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
        return
    if split_dim < 0:  # Normalize split dim
        # NOTE(review): only split_node's meta was validated above; this
        # assumes split_input.meta["val"] is also populated — confirm.
        split_dim += split_input.meta["val"].dim()

    new_args = (split_input, split_size)
    new_kwargs = {"dim": split_dim}
    # Already canonical: same args/kwargs and a plain call_function node.
    if (
        split_node.args == new_args
        and split_node.kwargs == new_kwargs
        and split_node.op == "call_function"
    ):
        return

    # Insert the canonical split right after the old one, reroute all users
    # to it, copy the metadata over, then drop the old node.
    with graph.inserting_after(split_node):
        new_split_node = graph.call_function(
            torch.ops.aten.split.Tensor,
            args=new_args,
            kwargs=new_kwargs,  # type: ignore[arg-type]
        )
    split_node.replace_all_uses_with(new_split_node)
    new_split_node.meta.update(split_node.meta)
    graph.erase_node(split_node)
    counters["inductor"]["normalization_aten_pass"] += 1
|
||||
|
||||
|
||||
@register_graph_pattern(
    CallFunction(
        torch.ops.aten.cat.default,
        getitem_split_aten,
        dim=Ignored(),
        _users=MULTIPLE,
    ),
    pass_dict=construct_pattern_matcher_pass("split_cat_aten_pass"),
)
def merge_split_cat_aten(match: Match, *args, **kwargs):
    """Collapse ``cat(split(x, ...)[0..n-1], dim)`` into ``x``.

    When a cat re-concatenates every section of a split, in order and along
    the same dim, the cat output is exactly the split input, so the cat,
    the getitems, and the split can all be removed from the graph.

    Bug fix: the original guard used ``and`` between the two validity
    checks, so a cat that consumed the sections out of order (or not
    one-to-one) could be wrongly replaced by the split input whenever the
    other check happened to pass.  The checks are now each sufficient to
    bail out, the section indices are read in cat-argument order, and the
    cat must start at section 0 and cover every section.
    """
    graph = match.graph
    split_node = match.nodes[0]
    split_input, _, split_dim = _get_split_args_default(split_node)
    # Robustness: bail out if the split args could not be recovered.
    if split_input is None or split_dim is None:
        return
    # get the getitem nodes from the split node
    getitem_nodes = list(split_node.users.keys())
    # NOTE(review): assumes the first user of the first getitem is the cat
    # matched by the pattern — confirm the matcher guarantees this.
    cat_node = next(iter(getitem_nodes[0].users.keys()))
    cat_dim = get_arg_value(cat_node, 1, "dim")
    cat_inputs = get_arg_value(cat_node, 0, "tensors")
    # check split node and cat node has same dim
    if split_dim != cat_dim:
        return
    # The rewrite is only sound if the cat consumes each split section
    # exactly once: same number of inputs as getitems, every input a
    # getitem of THIS split (the pattern is partial), and no section
    # produced by the split left unconsumed.
    if len(getitem_nodes) != len(cat_inputs):
        return
    if any(cat_input not in getitem_nodes for cat_input in cat_inputs):
        return
    if is_node_meta_valid(split_node) and len(cat_inputs) != len(
        split_node.meta["val"]
    ):
        return
    # Read section indices in cat-argument order so an out-of-order cat
    # (e.g. [getitem_1, getitem_0]) is rejected; they must be consecutive
    # and start at section 0.
    indices = [getitem.args[1] for getitem in cat_inputs]
    if not is_sorted_and_consecutive(indices) or indices[0] != 0:  # type: ignore[arg-type]
        return
    # replace the users of the cat node to be the input of the split node
    cat_node.replace_all_uses_with(split_input)
    # remove the cat node
    graph.erase_node(cat_node)
    # remove getitem nodes and split node with no users
    for getitem_node in getitem_nodes:
        if len(getitem_node.users) == 0:
            graph.erase_node(getitem_node)
    if len(split_node.users) == 0:
        graph.erase_node(split_node)
    counters["inductor"]["split_cat_aten_pass"] += 1
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE),
|
||||
pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"),
|
||||
|
|
@ -2034,10 +2127,6 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.No
|
|||
# cat (user=mul, dim=1, split_node)
|
||||
|
||||
|
||||
@register_graph_pattern(
|
||||
CallFunctionVarArgs(torch.cat, users=MULTIPLE),
|
||||
pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"),
|
||||
)
|
||||
@register_graph_pattern(
|
||||
CallFunction(
|
||||
torch.cat,
|
||||
|
|
|
|||
Loading…
Reference in a new issue