mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
A resubmit of https://github.com/pytorch/pytorch/pull/108447. Copy over the descriptions: This is a follow-up of the discussion in https://github.com/pytorch/pytorch/pull/108356, where we want to replace source_fn with source_fn_stack. Before this PR, for the following example: ```python backend = EagerAndRecordGraphs() @torch.compile(backend=backend, fullgraph=True) def cond_f(pred, pred2, x, y): def true_fn(pred2, x, y): return x + y def false_fn(pred2, x, y): def true_fn2(x, y): return x.sin() - y.cos() def false_fn2(x, y): return x.cos() - y.sin() return control_flow.cond(pred2, true_fn2, false_fn2, (x, y)) return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y)) ``` The graph captured is shown below: ```python class GraphModule(torch.nn.Module): def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor): l_pred_ = L_pred_ l_pred2_ = L_pred2_ l_x_ = L_x_ l_y_ = L_y_ cond_true_1 = self.cond_true_1 cond_false_1 = self.cond_false_1 cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None return (cond,) class GraphModule(torch.nn.Module): def forward(self, l_pred2_, l_x_, l_y_): add = l_x_ + l_y_; l_x_ = l_y_ = None return add class GraphModule(torch.nn.Module): def forward(self, l_pred2_, l_x_, l_y_): cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None return cond class GraphModule(torch.nn.Module): def forward(self, l_x_, l_y_): sin = l_x_.sin(); l_x_ = None cos = l_y_.cos(); l_y_ = None sub = sin - cos; sin = cos = None return sub class GraphModule(torch.nn.Module): def forward(self, l_x_, l_y_): cos = l_x_.cos(); l_x_ = None sin = l_y_.sin(); l_y_ = None sub = cos - sin; cos = sin = None return sub ``` the source_fn for inner cond, sin, cos 
will be a (name, target) tuple: ``` ('cond', <torch._ops.HigherOrderOperator object at xxx>) ('sin', 'sin') ('cos', 'cos') ('sub', <built-in function sub>) ``` After this PR, the source_fn_stack will be a list of (name, target) tuples. The bottom of the stack is the end of the list. ``` [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)], [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')], [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')], [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)] ``` Test Plan: See added tests in test_higher_order_ops.py and modified existing test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/108595 Approved by: https://github.com/angelayi, https://github.com/zou3519 |
||
|---|---|---|
| .. | ||
| attn_ft.py | ||
| attn_positional.py | ||
| common_utils.py | ||
| discover_coverage.py | ||
| functorch_additional_op_db.py | ||
| test_aotdispatch.py | ||
| test_control_flow.py | ||
| test_dims.py | ||
| test_eager_transforms.py | ||
| test_logging.py | ||
| test_memory_efficient_fusion.py | ||
| test_minifier.py | ||
| test_ops.py | ||
| test_parsing.py | ||
| test_rearrange.py | ||
| test_vmap.py | ||
| test_vmap_registrations.py | ||
| xfail_suggester.py | ||