From 5f7ea7ca6acb9fe2d60faa36fba89689f92f5747 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Thu, 7 Nov 2024 16:12:40 -0800
Subject: [PATCH] [invoke_subgraph] Support symint/int as inputs (#140058)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140058
Approved by: https://github.com/ydwu4, https://github.com/eellison
ghstack dependencies: #139162
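
A minimal sketch of the pattern this enables, adapted from the new test
below (it assumes mark_compile_region is imported from
torch._higher_order_ops.invoke_subgraph, as the test file does; that
import is outside this diff's context):

    import torch
    from torch._higher_order_ops.invoke_subgraph import mark_compile_region

    @mark_compile_region
    def gn(x):
        return torch.sin(x)

    def fn(x):
        return gn(x)

    x = torch.randn(8, 8, requires_grad=True)
    # Marking dim 0 dynamic turns its size into a SymInt, which now flows
    # into the invoke_subgraph HOP as an operand alongside the tensor.
    torch._dynamo.mark_dynamic(x, 0)
    out = torch.compile(fn, backend="inductor", fullgraph=True)(x)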
---
 test/higher_order_ops/test_invoke_subgraph.py | 15 +++++++++++
 torch/_higher_order_ops/invoke_subgraph.py    | 16 +++++++-----
 torch/_inductor/ir.py                         | 26 ++++++++++++++-----
 3 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py
index bb8e9a3d33b..7a938556b58 100644
--- a/test/higher_order_ops/test_invoke_subgraph.py
+++ b/test/higher_order_ops/test_invoke_subgraph.py
@@ -558,6 +558,21 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    def test_dynamic(self):
+        @mark_compile_region
+        def gn(x):
+            return torch.sin(x)
+
+        def fn(x):
+            return gn(x)
+
+        x = torch.randn(8, 8, requires_grad=True)
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x)
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py
index c72ef1a3190..b192e551669 100644
--- a/torch/_higher_order_ops/invoke_subgraph.py
+++ b/torch/_higher_order_ops/invoke_subgraph.py
@@ -15,6 +15,8 @@ from torch._higher_order_ops.utils import (
     get_dummy_aot_autograd_config,
     prepare_fw_with_masks,
     reenter_make_fx,
+    save_tensors_and_symints_for_backward,
+    saved_tensors_and_symints,
 )
 from torch._ops import HigherOrderOperator
 from torch._subclasses import FakeTensorMode
@@ -41,8 +43,8 @@ class InvokeSubgraphHOP(HigherOrderOperator):
         subgraph: GraphModule,
         identifier: Optional[str],
         operands: Union[
-            List[Union[torch.Tensor, torch.SymInt]],
-            Tuple[Union[torch.Tensor, torch.SymInt]],
+            List[Union[torch.Tensor, int, torch.SymInt]],
+            Tuple[Union[torch.Tensor, int, torch.SymInt]],
         ],
     ):
         assert identifier is None or isinstance(
@@ -51,10 +53,10 @@
 
         assert isinstance(
             operands, (list, tuple)
-        ), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}"
+        ), f"invoke_subgraph operands must be a list or tuple of tensors/ints/SymInts {operands}"
         assert all(
-            isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
-        ), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}"
+            isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
+        ), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
 
         return super().__call__(subgraph, identifier, operands)
 
@@ -188,14 +190,14 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
             operands,
         )
 
-        ctx.save_for_backward(*operands)
+        save_tensors_and_symints_for_backward(ctx, operands)
         return out
 
     @staticmethod
     def backward(ctx, *grad_outs):
         bw_graph = ctx._bw_graph
         identifier = ctx._identifier
-        primals = ctx.saved_tensors
+        primals = saved_tensors_and_symints(ctx)
        num_fw_outs = ctx._num_fw_outs
 
         # While tracing we made the assumption that tangents are contiguous. So,
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index df483c540bf..2d296c7a58a 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -6784,12 +6784,15 @@ class InvokeSubgraph(ExternKernel):
         def handle_sym_expr(stride):  # type: ignore[no-untyped-def]
             return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride]
 
-        fake_strides = [fake_operand.stride() for fake_operand in fake_operands]
-        fake_strides = [handle_sym_expr(stride) for stride in fake_strides]
-        operands = [
-            cls.require_exact_strides(x, fake_strides[idx])
-            for idx, x in enumerate(operands)
-        ]
+        new_operands = []
+        for idx, operand in enumerate(operands):
+            if isinstance(operand, ShapeAsConstantBuffer):
+                new_operands.append(operand)
+            else:
+                example_stride = handle_sym_expr(fake_operands[idx].stride())
+                new_operands.append(cls.require_exact_strides(operand, example_stride))
+
+        operands = new_operands
 
         if subgraph.graph is None:
             # create and lower subgraphs
@@ -6802,7 +6805,16 @@ class InvokeSubgraph(ExternKernel):
                 subgraph.graph.run(*fake_operands)
 
         outputs = subgraph.graph.graph_outputs
-        device = operands[0].get_device()
+
+        # Find the device - operands could be integers from shapes, so we can't
+        # use operands[0]
+        device = None
+        for operand in operands:
+            if not isinstance(operand, ShapeAsConstantBuffer):
+                device = operand.get_device()
+                break
+        assert device is not None
+
         invoke_subgraph = InvokeSubgraph(
             subgraph=subgraph,
             operands=operands,