[invoke_subgraph] Support symint/int as inputs (#140058)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140058
Approved by: https://github.com/ydwu4, https://github.com/eellison
ghstack dependencies: #139162
This commit is contained in:
parent d4cdc09881
commit 5f7ea7ca6a
3 changed files with 43 additions and 14 deletions
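The user-facing effect of the change: a region marked with mark_compile_region can now be compiled under dynamic shapes, where symbolic sizes (SymInts) and plain ints flow into the invoke_subgraph higher-order op as operands. Below is a minimal, hedged sketch of that scenario (not taken verbatim from the PR); it assumes mark_compile_region is importable from torch._higher_order_ops.invoke_subgraph, as in the test suite at the time, and uses x.shape[0] inside the region purely to illustrate a symbolic value reaching the subgraph.

# Hedged sketch: a compile region whose body consumes a dynamic size,
# so a SymInt becomes an operand of the invoke_subgraph HOP.
import torch
from torch._higher_order_ops.invoke_subgraph import mark_compile_region  # import path assumed

@mark_compile_region
def gn(x):
    # Under mark_dynamic, x.shape[0] is a SymInt rather than a constant.
    return torch.sin(x) * x.shape[0]

def fn(x):
    return gn(x)

x = torch.randn(8, 8, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)  # make dim 0 symbolic
opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
torch.testing.assert_close(opt_fn(x), fn(x))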
@@ -558,6 +558,21 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+    def test_dynamic(self):
+        @mark_compile_region
+        def gn(x):
+            return torch.sin(x)
+
+        def fn(x):
+            return gn(x)
+
+        x = torch.randn(8, 8, requires_grad=True)
+        torch._dynamo.mark_dynamic(x, 0)
+        ref = fn(x)
+        opt_fn = torch.compile(fn, backend="inductor", fullgraph=True)
+        res = opt_fn(x)
+        self.assertEqual(ref, res)
+
 
 if __name__ == "__main__":
     run_tests()
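In the new test, torch._dynamo.mark_dynamic(x, 0) marks dimension 0 of x as dynamic, so Dynamo traces it with a SymInt size; that is what routes symint operands into the invoke_subgraph region and exercises the support added in the hunks below.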
@@ -15,6 +15,8 @@ from torch._higher_order_ops.utils import (
     get_dummy_aot_autograd_config,
     prepare_fw_with_masks,
     reenter_make_fx,
+    save_tensors_and_symints_for_backward,
+    saved_tensors_and_symints,
 )
 from torch._ops import HigherOrderOperator
 from torch._subclasses import FakeTensorMode
@@ -41,8 +43,8 @@ class InvokeSubgraphHOP(HigherOrderOperator):
         subgraph: GraphModule,
         identifier: Optional[str],
         operands: Union[
-            List[Union[torch.Tensor, torch.SymInt]],
-            Tuple[Union[torch.Tensor, torch.SymInt]],
+            List[Union[torch.Tensor, int, torch.SymInt]],
+            Tuple[Union[torch.Tensor, int, torch.SymInt]],
         ],
     ):
         assert identifier is None or isinstance(
@@ -51,10 +53,10 @@ class InvokeSubgraphHOP(HigherOrderOperator):
 
         assert isinstance(
             operands, (list, tuple)
-        ), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}"
+        ), f"invoke_subgraph operands must be a list or tuple of tensors/ints/SymInts {operands}"
         assert all(
-            isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands
-        ), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}"
+            isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
+        ), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
 
         return super().__call__(subgraph, identifier, operands)
 
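The two hunks above widen the accepted operand types from tensors/SymInts to tensors/ints/SymInts, both in the __call__ type annotation and in the runtime asserts. A standalone illustration of the widened check, using plain Python rather than the HOP machinery (the sample operands are illustrative only):

import torch

# A tensor plus a plain Python int: the old check against
# (torch.Tensor, torch.SymInt) would have rejected the int.
operands = (torch.randn(4, 4), 8)
assert isinstance(operands, (list, tuple))
assert all(isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands)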
@@ -188,14 +190,14 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
             operands,
         )
 
-        ctx.save_for_backward(*operands)
+        save_tensors_and_symints_for_backward(ctx, operands)
         return out
 
     @staticmethod
     def backward(ctx, *grad_outs):
         bw_graph = ctx._bw_graph
         identifier = ctx._identifier
-        primals = ctx.saved_tensors
+        primals = saved_tensors_and_symints(ctx)
         num_fw_outs = ctx._num_fw_outs
 
         # While tracing we made the assumption that tangents are contiguous. So,
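ctx.save_for_backward only accepts tensors, so once ints and SymInts can appear among the operands the autograd op has to stash them some other way; that is why the hunk above switches to the save_tensors_and_symints_for_backward / saved_tensors_and_symints pair from torch._higher_order_ops.utils. A hedged, self-contained sketch of the same pattern with an ordinary custom autograd.Function (the Scale example and its names are illustrative, not part of the PR):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        # x (a tensor) goes through save_for_backward; factor (a plain int)
        # cannot, so it is kept directly on ctx - the same split the
        # symint-aware helpers perform for invoke_subgraph operands.
        ctx.save_for_backward(x)
        ctx._factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * ctx._factor, None

y = Scale.apply(torch.randn(3, requires_grad=True), 4)
y.sum().backward()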
@@ -6784,12 +6784,15 @@ class InvokeSubgraph(ExternKernel):
         def handle_sym_expr(stride):  # type: ignore[no-untyped-def]
             return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride]
 
-        fake_strides = [fake_operand.stride() for fake_operand in fake_operands]
-        fake_strides = [handle_sym_expr(stride) for stride in fake_strides]
-        operands = [
-            cls.require_exact_strides(x, fake_strides[idx])
-            for idx, x in enumerate(operands)
-        ]
+        new_operands = []
+        for idx, operand in enumerate(operands):
+            if isinstance(operand, ShapeAsConstantBuffer):
+                new_operands.append(operand)
+            else:
+                example_stride = handle_sym_expr(fake_operands[idx].stride())
+                new_operands.append(cls.require_exact_strides(operand, example_stride))
+
+        operands = new_operands
 
         if subgraph.graph is None:
             # create and lower subgraphs
@@ -6802,7 +6805,16 @@ class InvokeSubgraph(ExternKernel):
                 subgraph.graph.run(*fake_operands)
 
         outputs = subgraph.graph.graph_outputs
-        device = operands[0].get_device()
+
+        # Find the device - operands could be integers from shapes, so we can't
+        # use operands[0]
+        device = None
+        for operand in operands:
+            if not isinstance(operand, ShapeAsConstantBuffer):
+                device = operand.get_device()
+                break
+        assert device is not None
+
         invoke_subgraph = InvokeSubgraph(
             subgraph=subgraph,
             operands=operands,
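On the inductor side, operands that carry symbolic shape values arrive as ShapeAsConstantBuffer rather than tensor buffers, so the stride fix-up in the first hunk skips them and the second hunk takes the kernel's device from the first tensor-like operand instead of blindly using operands[0]. A hedged, simplified stand-in for that device search, using plain tensors and ints instead of inductor IR nodes:

import torch

def pick_device(operands):
    # Mirrors the loop in the diff: skip non-tensor operands (ints standing
    # in for ShapeAsConstantBuffer) and take the device of the first tensor.
    for operand in operands:
        if isinstance(operand, torch.Tensor):
            return operand.device
    raise AssertionError("expected at least one tensor operand")

print(pick_device([8, torch.randn(2, 2)]))  # cpu, even though operands[0] is an int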