[invoke_subgraph] Support symint/int as inputs (#140058)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140058
Approved by: https://github.com/ydwu4, https://github.com/eellison
ghstack dependencies: #139162
This commit is contained in:
Animesh Jain 2024-11-07 16:12:40 -08:00 committed by PyTorch MergeBot
parent d4cdc09881
commit 5f7ea7ca6a
3 changed files with 43 additions and 14 deletions

View file

@ -558,6 +558,21 @@ class GraphModule(torch.nn.Module):
""",
)
def test_dynamic(self):
@mark_compile_region
def gn(x):
return torch.sin(x)
def fn(x):
return gn(x)
x = torch.randn(8, 8, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
ref = fn(x)
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref, res)
if __name__ == "__main__":
run_tests()

View file

@ -15,6 +15,8 @@ from torch._higher_order_ops.utils import (
get_dummy_aot_autograd_config,
prepare_fw_with_masks,
reenter_make_fx,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
@ -41,8 +43,8 @@ class InvokeSubgraphHOP(HigherOrderOperator):
subgraph: GraphModule,
identifier: Optional[str],
operands: Union[
List[Union[torch.Tensor, torch.SymInt]],
Tuple[Union[torch.Tensor, torch.SymInt]],
List[Union[torch.Tensor, int, torch.SymInt]],
Tuple[Union[torch.Tensor, int, torch.SymInt]],
],
):
assert identifier is None or isinstance(
@ -51,10 +53,10 @@ class InvokeSubgraphHOP(HigherOrderOperator):
assert isinstance(
operands, (list, tuple)
), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}"
), f"invoke_subgraph operands must be a list or tuple of tensors/ints/SymInts {operands}"
assert all(
isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}"
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
return super().__call__(subgraph, identifier, operands)
@ -188,14 +190,14 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
operands,
)
ctx.save_for_backward(*operands)
save_tensors_and_symints_for_backward(ctx, operands)
return out
@staticmethod
def backward(ctx, *grad_outs):
bw_graph = ctx._bw_graph
identifier = ctx._identifier
primals = ctx.saved_tensors
primals = saved_tensors_and_symints(ctx)
num_fw_outs = ctx._num_fw_outs
# While tracing we made the assumption that tangents are contiguous. So,

View file

@ -6784,12 +6784,15 @@ class InvokeSubgraph(ExternKernel):
def handle_sym_expr(stride): # type: ignore[no-untyped-def]
return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride]
fake_strides = [fake_operand.stride() for fake_operand in fake_operands]
fake_strides = [handle_sym_expr(stride) for stride in fake_strides]
operands = [
cls.require_exact_strides(x, fake_strides[idx])
for idx, x in enumerate(operands)
]
new_operands = []
for idx, operand in enumerate(operands):
if isinstance(operand, ShapeAsConstantBuffer):
new_operands.append(operand)
else:
example_stride = handle_sym_expr(fake_operands[idx].stride())
new_operands.append(cls.require_exact_strides(operand, example_stride))
operands = new_operands
if subgraph.graph is None:
# create and lower subgraphs
@ -6802,7 +6805,16 @@ class InvokeSubgraph(ExternKernel):
subgraph.graph.run(*fake_operands)
outputs = subgraph.graph.graph_outputs
device = operands[0].get_device()
# Find the device - operands could be integers from shapes, so we can't
# use operands[0]
device = None
for operand in operands:
if not isinstance(operand, ShapeAsConstantBuffer):
device = operand.get_device()
break
assert device is not None
invoke_subgraph = InvokeSubgraph(
subgraph=subgraph,
operands=operands,