[while_loop][aot] auto-unspecialize int input and output to unbacked symints (#143105)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143105
Approved by: https://github.com/zou3519
This commit is contained in:
Yidi Wu 2025-01-02 14:19:45 -08:00 committed by PyTorch MergeBot
parent 56f6289f6a
commit 6e8dca9ff3
2 changed files with 343 additions and 32 deletions

View file

@ -3815,6 +3815,13 @@ class TestControlFlowTraced(TestCase):
compiled_fn = torch.compile(fn, backend=backend)
self.assertEqual(compiled_fn(*args), eager_res)
def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
    """Export ``fn`` with ``args`` and verify parity with eager execution.

    Runs ``fn`` eagerly, exports it via ``torch.export.export`` with the
    given ``strict``/``dynamic_shapes`` settings, runs the exported module
    on the same args, and asserts both outputs match.

    Returns the exported program so callers can inspect its graph.
    """
    eager_out = fn(*args)
    exported = torch.export.export(
        fn, args, strict=strict, dynamic_shapes=dynamic_shapes
    )
    self.assertEqual(eager_out, exported.module()(*args))
    return exported
def test_cond_traced_not_nested(self):
def true_fn(x):
return x.sin()
@ -6437,6 +6444,184 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
self.assertEqual(eager_out, exp_out)
self.assertEqual(compiled_out, exp_out)
# TODO: add dynamo support for int carries
@skipIfTorchDynamo("Skip because we haven't support dynamo")
@parametrize("strict", [False])
@parametrize("dynamic", [True, False])
def test_while_loop_op_int_carry_export(self, strict, dynamic):
"""Export a module whose while_loop carries a plain Python int.

The expected graph below shows the int carry/output being traced as
unbacked symints (u0 inside the subgraphs, u1 for the loop output).
"""
class Mod(torch.nn.Module):
def forward(self, x):
def cond_fn(it, x):
return it < x.shape[0]
def body_fn(it, x):
x_clone = x.clone()
# Need these checks to select from x
torch._check(it >= 0)
torch._check(it < x.shape[0])
x_clone.select(0, it).copy_(x_clone.select(0, it) + it)
return it + 1, x_clone
# We invoke the hop directly to avoid triggering dynamo tracing
out_it, out_x = torch.ops.higher_order.while_loop(
cond_fn, body_fn, (0, x), tuple()
)
# We need torch._check to use it in torch.ones call
torch._check(out_it > 0)
return (
out_it + 1,
out_it + out_x,
out_it < x.shape[0],
torch.ones(out_it * 2),
)
# Eager Run:
x = torch.randn((2, 3), requires_grad=True)
m = Mod()
dynamic_shapes = None
if dynamic:
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}}
ep = self._check_export(m, (x,), strict=strict, dynamic_shapes=dynamic_shapes)
# Only the non-strict + dynamic-shapes variant has a pinned expected graph.
if not strict and dynamic:
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s0, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (0, x), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x = None
getitem: "Sym(u1)" = while_loop[0]
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None
add: "Sym(u1 + 1)" = getitem + 1
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem); getitem_1 = None
lt: "Sym(u1 < s0)" = getitem < sym_size_int_1; sym_size_int_1 = None
mul: "Sym(2*u1)" = getitem * 2; getitem = None
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
copy_: "f32[3]" = torch.ops.aten.copy_.default(select, add); select = add = copy_ = None
add_1: "Sym(u0 + 1)" = it_1 + 1; it_1 = None
return (add_1, clone)
""", # noqa: B950
)
# TODO: add dynamo support for int carries
@skipIfTorchDynamo("Skip because we haven't support dynamo")
@parametrize("strict", [False])
@parametrize("dynamic", [True, False])
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_while_loop_op_constant_and_symint_output(self, strict, dynamic):
"""Export a while_loop whose carries mix constants, backed symints
(shapes), a data-dependent symint (.item()), and a tensor.

The expected graph shows every int-like carry being unspecialized to a
fresh unbacked symint (u8..u14) in the while_loop outputs.
"""
class Mod(torch.nn.Module):
def forward(self, t):
a = t.shape[0]
b = t.shape[1]
def cond_fn(a, b, c1, c2, c3, c0, u0, x):
return c1 * c2 * c3 < a * b
def body_fn(a, b, c1, c2, c3, c0, u0, x):
return b, c1, c2, c3, a, 0, u0 + 1, x + 1
carry = (a, b, 1, 1, 1, a + 1, t.sum().to(torch.int64).item(), t.sin())
out_it = torch.ops.higher_order.while_loop(
cond_fn, body_fn, carry, tuple()
)
out_inc = pytree.tree_map(lambda x: x + 1, out_it)
out_add = pytree.tree_map(lambda x: x + t, out_it)
return out_inc, out_add
dynamic_shapes = {"t": {0: torch.export.Dim("dim_t")}} if dynamic else None
x = torch.randn(2, 3, requires_grad=True) # trigger autograd key
m = Mod()
ep = self._check_export(m, (x,), strict=strict, dynamic_shapes=dynamic_shapes)
# Only the non-strict + dynamic-shapes variant has a pinned expected graph.
if not strict and dynamic:
self.assertExpectedInline(
normalize_gm(ep.module().print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, t):
t: "f32[s0, 3]";
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(t, 0)
add: "Sym(s0 + 1)" = sym_size_int_1 + 1
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(t)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 1, 1, 1, add, item, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = add = item = sin = None
getitem: "Sym(u8)" = while_loop[0]
getitem_1: "Sym(u9)" = while_loop[1]
getitem_2: "Sym(u10)" = while_loop[2]
getitem_3: "Sym(u11)" = while_loop[3]
getitem_4: "Sym(u12)" = while_loop[4]
getitem_5: "Sym(u13)" = while_loop[5]
getitem_6: "Sym(u14)" = while_loop[6]
getitem_7: "f32[s0, 3]" = while_loop[7]; while_loop = None
add_1: "Sym(u8 + 1)" = getitem + 1
add_2: "Sym(u9 + 1)" = getitem_1 + 1
add_3: "Sym(u10 + 1)" = getitem_2 + 1
add_4: "Sym(u11 + 1)" = getitem_3 + 1
add_5: "Sym(u12 + 1)" = getitem_4 + 1
add_6: "Sym(u13 + 1)" = getitem_5 + 1
add_7: "Sym(u14 + 1)" = getitem_6 + 1
add_8: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_7, 1)
add_9: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem); getitem = None
add_10: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_1); getitem_1 = None
add_11: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_2); getitem_2 = None
add_12: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_3); getitem_3 = None
add_13: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_4); getitem_4 = None
add_14: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_5); getitem_5 = None
add_15: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_6); getitem_6 = None
add_16: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_7, t); getitem_7 = t = None
return pytree.tree_unflatten((add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15, add_16), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[s0, 3]"):
mul: "Sym(u3*u4)" = c1_1 * c2_1; c1_1 = c2_1 = None
mul_1: "Sym(u3*u4*u5)" = mul * c3_1; mul = c3_1 = None
mul_2: "Sym(u1*u2)" = a_1 * b_1; a_1 = b_1 = None
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[s0, 3]"):
add: "Sym(u7 + 1)" = u0_1 + 1; u0_1 = None
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None
return (b_1, c1_1, c2_1, c3_1, a_1, 0, add, add_1)
""", # noqa: B950
)
_hop_schema_test_schema_types = [
"bool",

View file

@ -75,14 +75,16 @@ def while_loop(cond_fn, body_fn, carried_inputs):
return val
Args:
cond_fn (Callable): A callable function that returns a boolean Scalar tensor.
cond_fn (Callable): A callable function that returns a boolean Scalar tensor or a python boolean.
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors or ints
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
the initial value of states that are carried across iterations.
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints): A tuple of inputs to cond_fn and body_fn.
It's also the initial value of states that are carried across iterations. Note that when passing an integer as a carry,
the corresponding return of while_loop will be another int with an unknown value because we don't know how many
iterations while_loop will run.
Example:
Example 1:
def cond_fn(iter, x):
return iter.sum() < 10
@ -92,9 +94,19 @@ def while_loop(cond_fn, body_fn, carried_inputs):
while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
Example 2:
def cond_fn(int_iter, x):
return 2 * int_iter < x.shape[0]
def body_fn(int_iter, x):
return int_iter + 1, x + int_iter
while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))
Restrictions:
- body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.
- body_fn must return tensors or ints with the same metadata (e.g. shape, dtype) as inputs.
- body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
@ -171,12 +183,17 @@ def while_loop(cond_fn, body_fn, carried_inputs):
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
carried_vals = carried_inputs
def _is_boolean_scalar_tensor(pred):
return (
def _validate_cond_output(pred):
if (
isinstance(pred, torch.Tensor)
and pred.size() == torch.Size([])
and pred.dtype == torch.bool
)
) or isinstance(pred, bool):
return
else:
raise RuntimeError(
f"cond_fn must return a boolean scalar tensor or a boolean but got {pred}"
)
if not isinstance(carried_inputs, tuple):
raise RuntimeError(
@ -184,10 +201,7 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
)
while pred := cond_fn(*carried_vals, *additional_inputs):
if not _is_boolean_scalar_tensor(pred):
raise RuntimeError(
f"cond_fn must return a boolean scalar tensor but got {pred}"
)
_validate_cond_output(pred)
out = body_fn(*carried_vals, *additional_inputs)
assert isinstance(
out, tuple
@ -204,13 +218,79 @@ while_loop_op.py_impl(DispatchKey.Autograd)(
)
def _find_or_create_fake_mode() -> FakeTensorMode:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
fake_mode = torch._guards.detect_fake_mode()
if fake_mode is None:
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
return fake_mode
def _create_unbacked_symint(fake_mode: FakeTensorMode) -> torch.SymInt:
assert (
fake_mode is not None and fake_mode.shape_env is not None
), "Must provide a fake_mode with shape_env."
with fake_mode.shape_env.ignore_fresh_unbacked_symbols():
return fake_mode.shape_env.create_unbacked_symint()
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
def _trace_while_loop(
proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
):
cond_graph = reenter_make_fx(cond_fn)(*carried_inputs, *additional_inputs)
body_graph = reenter_make_fx(body_fn)(*carried_inputs, *additional_inputs)
# NOTE [unspecialize int carry with unbacked symints]
# When we support int carry, we'll also need to support int output of body_fn because
# previous iteration's output is next iteration's input and they must match.
# For carries, when we start tracing while_loop, they can be
# - constants e.g. (0, [1, 3])
# - backed symints (x.shape[0], [x.shape[1] + x.stride[1], x.shape[2]])
# - unbacked symints e.g. (u0, [u0 + u1, u2])
# We choose the most conservative design: in all cases, we create new unbacked symints to trace the
# subgraph. It's possible to do some analysis on initial carry and the output of first
# iteration to determine a better range for the output unbacked symbol e.g. when input is an unbacked
# symint >= 0 before the while_loop but in general this is difficult because we don't know
# the number of iterations. Users would have to re-constrain the unbacked symint in subgraph if needed.
#
# For output of fake cond_fn, it could be constant bool or SymBool (e.g. return x.shape[0] < 4,
# where x.shape[0] can be either static or dynamic). In the case of constant bool, we should do a
# specialization (NYI).
# For output of fake body_fn, it could be all three types though from user's point of view,
# they're all integers e.g.
# init_carry = (0, s0, u1, t)
# def body_fn(u0, s0, u1, t):
# ...
# return (t.shape[0], t.shape[1], t.shape[2], y + 1)
#
# It may seem that a constant output isn't possible: users shouldn't write a while_loop
# that always return 0. But it could be that a shape is not set as dynamic properly (e.g.
# automatic dynamic hasn't been triggered).
#
# For this reason, we treat int, symint outputs in the same way:
# - they can match against any of int, symint carry
# - we unspecialize them with new unbacked symints in fake while_loop
# Similarly, we could do some analysis to refine the output ranges but it's easier to start with
# fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by
# users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each
# iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the
# unbacked symint output of subgraph as of today because this requires a smart range analysis.
fake_mode: FakeTensorMode = _find_or_create_fake_mode()
unspecialized_carried_inputs = pytree.tree_map_only(
(int, torch.SymInt),
lambda _: _create_unbacked_symint(fake_mode),
carried_inputs,
)
cond_graph = reenter_make_fx(cond_fn)(
*unspecialized_carried_inputs, *additional_inputs
)
body_graph = reenter_make_fx(body_fn)(
*unspecialized_carried_inputs, *additional_inputs
)
next_name = None
i = 0
@ -235,7 +315,9 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
"call_function", while_loop_op, proxy_args, {}, name="while_loop"
)
out = while_loop_op(cond_graph, body_graph, carried_inputs, additional_inputs)
out = while_loop_op(
cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
)
return track_tensor_tree(
out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
@ -246,16 +328,57 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
def check_outputs_carry_consistency(
outs: List[torch.Tensor], carries: List[torch.Tensor]
outs: List[Union[torch.Tensor, torch.SymInt, int]],
carries: List[Union[torch.Tensor, torch.SymInt, int]],
) -> None:
all_diffs_in_meta = []
for out, cry in zip(outs, carries):
if diff := diff_tensor_meta(
_extract_tensor_metadata(cry), _extract_tensor_metadata(out)
):
all_diffs_in_meta.append(",".join(diff))
if all_diffs_in_meta:
diff_str = "\n".join(all_diffs_in_meta)
def diff_meta_pairs(
lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
) -> List[str]:
def diff_meta(
lhs: Union[torch.Tensor, torch.SymInt, int],
rhs: Union[torch.Tensor, torch.SymInt, int],
) -> str:
if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
return ", ".join(
diff_tensor_meta(
# We set include contiguity=False because we have vmap x cond tests, where if
# include_contiguity=True will call t.is_contiguous inside of vmap and get an error
# "querying is_contiguous inside of vmap for memory_format other than
# torch.contiguous_format is not yet implemented". This is fine because stride
# is still checked.
_extract_tensor_metadata(lhs, include_contiguity=False),
_extract_tensor_metadata(rhs, include_contiguity=False),
check_grad=False,
)
)
else:
def _both_int_types(lhs, rhs):
return isinstance(lhs, (int, torch.SymInt)) and isinstance(
rhs, (int, torch.SymInt)
)
def _both_tensor(lhs, rhs):
return isinstance(lhs, torch.Tensor) and isinstance(
rhs, torch.Tensor
)
if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
return f"type: {lhs} vs {rhs}"
return ""
all_diffs = []
for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
if diff := diff_meta(lhs, rhs):
all_diffs.append(
f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
)
return all_diffs
if all_diffs := diff_meta_pairs(outs, carries):
diff_str = "\n".join(all_diffs)
raise RuntimeError(
f"Expected carried_inputs and body outputs return tensors with same metadata but found:\n{diff_str}"
)
@ -266,12 +389,12 @@ def while_loop_fake_tensor_mode(
mode, cond_fn, body_fn, carried_inputs, additional_inputs
):
with mode:
# NOTE: [Handling unback symints created in subgraph of while_loop]
# NOTE: [Handling unback symints in subgraph of while_loop]
# The idea is that the scope of unbacked symints are limited to the subgraph.
#
# We're implementing the fake tensor mode of while_loop operator.
# and we run body_fn once to get an fake output.
# Let's only consider tensor output for now:
# Let's first consider the case that unbacked symints are tensor shapes:
#
# Case 1:
# if the unbacked symints is local to the subgraph e.g.
@ -282,8 +405,8 @@ def while_loop_fake_tensor_mode(
# no effect on the output of while_loop and it's tracked when we tracing.
# the subgraph.
#
# Case 2.1:
# if the unbacked symints are part of output of while_loop e.g.
# Case 2:
# if the unbacked symints are shape of output of while_loop e.g.
# def body_fn(it, x):
# nz = x.nonzero()
# return it+1, nz
@ -291,8 +414,8 @@ def while_loop_fake_tensor_mode(
# must match the output shape as nz.shape contains newly allocated unbacked symint, this
# won't match the carried_input's shape.
#
# Case 2.2:
# if the unbacked symints are part of carried_inputs e.g.
# Case 3:
# if the unbacked symints are shape of carried_inputs e.g.
# nz = a.nonzero()
# body_fn(it, nz):
# return it+1, nz.sin() + 1
@ -302,7 +425,10 @@ def while_loop_fake_tensor_mode(
# so we could just return the output after one iteration.
body_outs = body_fn(*carried_inputs, *additional_inputs)
check_outputs_carry_consistency(body_outs, carried_inputs)
return body_outs
# See NOTE [unspecialize int carry with unbacked symints]
return pytree.tree_map_only(
(int, torch.SymInt), lambda _: _create_unbacked_symint(mode), body_outs
)
@while_loop_op.py_functionalize_impl