mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[while_loop][aot] auto-unspecialize int input and output to unbacked symints (#143105)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143105 Approved by: https://github.com/zou3519
This commit is contained in:
parent
56f6289f6a
commit
6e8dca9ff3
2 changed files with 343 additions and 32 deletions
|
|
@ -3815,6 +3815,13 @@ class TestControlFlowTraced(TestCase):
|
|||
compiled_fn = torch.compile(fn, backend=backend)
|
||||
self.assertEqual(compiled_fn(*args), eager_res)
|
||||
|
||||
def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
|
||||
eg_out = fn(*args)
|
||||
ep = torch.export.export(fn, args, strict=strict, dynamic_shapes=dynamic_shapes)
|
||||
ep_out = ep.module()(*args)
|
||||
self.assertEqual(eg_out, ep_out)
|
||||
return ep
|
||||
|
||||
def test_cond_traced_not_nested(self):
|
||||
def true_fn(x):
|
||||
return x.sin()
|
||||
|
|
@ -6437,6 +6444,184 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_
|
|||
self.assertEqual(eager_out, exp_out)
|
||||
self.assertEqual(compiled_out, exp_out)
|
||||
|
||||
# TODO: add dynamo support for int carries
|
||||
@skipIfTorchDynamo("Skip because we haven't support dynamo")
|
||||
@parametrize("strict", [False])
|
||||
@parametrize("dynamic", [True, False])
|
||||
def test_while_loop_op_int_carry_export(self, strict, dynamic):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
def cond_fn(it, x):
|
||||
return it < x.shape[0]
|
||||
|
||||
def body_fn(it, x):
|
||||
x_clone = x.clone()
|
||||
# Need these checks to select from x
|
||||
torch._check(it >= 0)
|
||||
torch._check(it < x.shape[0])
|
||||
x_clone.select(0, it).copy_(x_clone.select(0, it) + it)
|
||||
return it + 1, x_clone
|
||||
|
||||
# We invoke the hop directly to avoid triggering dynamo tracing
|
||||
out_it, out_x = torch.ops.higher_order.while_loop(
|
||||
cond_fn, body_fn, (0, x), tuple()
|
||||
)
|
||||
# We need torch._check to use it in torch.ones call
|
||||
torch._check(out_it > 0)
|
||||
return (
|
||||
out_it + 1,
|
||||
out_it + out_x,
|
||||
out_it < x.shape[0],
|
||||
torch.ones(out_it * 2),
|
||||
)
|
||||
|
||||
# Eager Run:
|
||||
x = torch.randn((2, 3), requires_grad=True)
|
||||
m = Mod()
|
||||
dynamic_shapes = None
|
||||
if dynamic:
|
||||
dynamic_shapes = {"x": {0: torch.export.Dim("dim_x")}}
|
||||
|
||||
ep = self._check_export(m, (x,), strict=strict, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
if not strict and dynamic:
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(ep.module().print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x: "f32[s0, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (0, x), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x = None
|
||||
getitem: "Sym(u1)" = while_loop[0]
|
||||
getitem_1: "f32[s0, 3]" = while_loop[1]; while_loop = None
|
||||
|
||||
add: "Sym(u1 + 1)" = getitem + 1
|
||||
|
||||
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_1, getitem); getitem_1 = None
|
||||
|
||||
lt: "Sym(u1 < s0)" = getitem < sym_size_int_1; sym_size_int_1 = None
|
||||
|
||||
mul: "Sym(2*u1)" = getitem * 2; getitem = None
|
||||
ones: "f32[2*u1]" = torch.ops.aten.ones.default([mul], device = device(type='cpu'), pin_memory = False); mul = None
|
||||
return pytree.tree_unflatten((add, add_1, lt, ones), self._out_spec)
|
||||
|
||||
class while_loop_cond_graph_0(torch.nn.Module):
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
|
||||
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
|
||||
lt: "Sym(u0 < s0)" = it_1 < sym_size_int; it_1 = sym_size_int = None
|
||||
return lt
|
||||
|
||||
class while_loop_body_graph_0(torch.nn.Module):
|
||||
def forward(self, it_1: "Sym(u0)", x_1: "f32[s0, 3]"):
|
||||
clone: "f32[s0, 3]" = torch.ops.aten.clone.default(x_1); x_1 = None
|
||||
select: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
||||
select_1: "f32[3]" = torch.ops.aten.select.int(clone, 0, it_1)
|
||||
add: "f32[3]" = torch.ops.aten.add.Tensor(select_1, it_1); select_1 = None
|
||||
copy_: "f32[3]" = torch.ops.aten.copy_.default(select, add); select = add = copy_ = None
|
||||
add_1: "Sym(u0 + 1)" = it_1 + 1; it_1 = None
|
||||
return (add_1, clone)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
# TODO: add dynamo support for int carries
|
||||
@skipIfTorchDynamo("Skip because we haven't support dynamo")
|
||||
@parametrize("strict", [False])
|
||||
@parametrize("dynamic", [True, False])
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_while_loop_op_constant_and_symint_output(self, strict, dynamic):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, t):
|
||||
a = t.shape[0]
|
||||
b = t.shape[1]
|
||||
|
||||
def cond_fn(a, b, c1, c2, c3, c0, u0, x):
|
||||
return c1 * c2 * c3 < a * b
|
||||
|
||||
def body_fn(a, b, c1, c2, c3, c0, u0, x):
|
||||
return b, c1, c2, c3, a, 0, u0 + 1, x + 1
|
||||
|
||||
carry = (a, b, 1, 1, 1, a + 1, t.sum().to(torch.int64).item(), t.sin())
|
||||
out_it = torch.ops.higher_order.while_loop(
|
||||
cond_fn, body_fn, carry, tuple()
|
||||
)
|
||||
out_inc = pytree.tree_map(lambda x: x + 1, out_it)
|
||||
out_add = pytree.tree_map(lambda x: x + t, out_it)
|
||||
return out_inc, out_add
|
||||
|
||||
dynamic_shapes = {"t": {0: torch.export.Dim("dim_t")}} if dynamic else None
|
||||
x = torch.randn(2, 3, requires_grad=True) # trigger autograd key
|
||||
m = Mod()
|
||||
ep = self._check_export(m, (x,), strict=strict, dynamic_shapes=dynamic_shapes)
|
||||
if not strict and dynamic:
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(ep.module().print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, t):
|
||||
t: "f32[s0, 3]";
|
||||
|
||||
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
|
||||
sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(t, 0)
|
||||
|
||||
add: "Sym(s0 + 1)" = sym_size_int_1 + 1
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
|
||||
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
|
||||
item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None
|
||||
sin: "f32[s0, 3]" = torch.ops.aten.sin.default(t)
|
||||
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
while_loop_body_graph_0 = self.while_loop_body_graph_0
|
||||
while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (sym_size_int_1, 3, 1, 1, 1, add, item, sin), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = sym_size_int_1 = add = item = sin = None
|
||||
getitem: "Sym(u8)" = while_loop[0]
|
||||
getitem_1: "Sym(u9)" = while_loop[1]
|
||||
getitem_2: "Sym(u10)" = while_loop[2]
|
||||
getitem_3: "Sym(u11)" = while_loop[3]
|
||||
getitem_4: "Sym(u12)" = while_loop[4]
|
||||
getitem_5: "Sym(u13)" = while_loop[5]
|
||||
getitem_6: "Sym(u14)" = while_loop[6]
|
||||
getitem_7: "f32[s0, 3]" = while_loop[7]; while_loop = None
|
||||
|
||||
add_1: "Sym(u8 + 1)" = getitem + 1
|
||||
add_2: "Sym(u9 + 1)" = getitem_1 + 1
|
||||
add_3: "Sym(u10 + 1)" = getitem_2 + 1
|
||||
add_4: "Sym(u11 + 1)" = getitem_3 + 1
|
||||
add_5: "Sym(u12 + 1)" = getitem_4 + 1
|
||||
add_6: "Sym(u13 + 1)" = getitem_5 + 1
|
||||
add_7: "Sym(u14 + 1)" = getitem_6 + 1
|
||||
add_8: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_7, 1)
|
||||
|
||||
add_9: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem); getitem = None
|
||||
add_10: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_1); getitem_1 = None
|
||||
add_11: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_2); getitem_2 = None
|
||||
add_12: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_3); getitem_3 = None
|
||||
add_13: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_4); getitem_4 = None
|
||||
add_14: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_5); getitem_5 = None
|
||||
add_15: "f32[s0, 3]" = torch.ops.aten.add.Tensor(t, getitem_6); getitem_6 = None
|
||||
add_16: "f32[s0, 3]" = torch.ops.aten.add.Tensor(getitem_7, t); getitem_7 = t = None
|
||||
return pytree.tree_unflatten((add_1, add_2, add_3, add_4, add_5, add_6, add_7, add_8, add_9, add_10, add_11, add_12, add_13, add_14, add_15, add_16), self._out_spec)
|
||||
|
||||
class while_loop_cond_graph_0(torch.nn.Module):
|
||||
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[s0, 3]"):
|
||||
mul: "Sym(u3*u4)" = c1_1 * c2_1; c1_1 = c2_1 = None
|
||||
mul_1: "Sym(u3*u4*u5)" = mul * c3_1; mul = c3_1 = None
|
||||
mul_2: "Sym(u1*u2)" = a_1 * b_1; a_1 = b_1 = None
|
||||
lt: "Sym(u3*u4*u5 < u1*u2)" = mul_1 < mul_2; mul_1 = mul_2 = None
|
||||
return lt
|
||||
|
||||
class while_loop_body_graph_0(torch.nn.Module):
|
||||
def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4)", c3_1: "Sym(u5)", c0_1: "Sym(u6)", u0_1: "Sym(u7)", x_1: "f32[s0, 3]"):
|
||||
add: "Sym(u7 + 1)" = u0_1 + 1; u0_1 = None
|
||||
add_1: "f32[s0, 3]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None
|
||||
return (b_1, c1_1, c2_1, c3_1, a_1, 0, add, add_1)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
||||
_hop_schema_test_schema_types = [
|
||||
"bool",
|
||||
|
|
|
|||
|
|
@ -75,14 +75,16 @@ def while_loop(cond_fn, body_fn, carried_inputs):
|
|||
return val
|
||||
|
||||
Args:
|
||||
cond_fn (Callable): A callable function that returns a boolean Scalar tensor.
|
||||
cond_fn (Callable): A callable function that returns a boolean Scalar tensor or a python boolean.
|
||||
|
||||
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors
|
||||
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors or ints
|
||||
|
||||
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
|
||||
the initial value of states that are carried across iterations.
|
||||
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints): A tuple of inputs to cond_fn and body_fn.
|
||||
It's also the initial value of states that are carried across iterations. Note that when passing an integer as a carry,
|
||||
the corresponding return of while_loop will be another int with unknown values because we don't know how many
|
||||
iterations while_loop will run.
|
||||
|
||||
Example:
|
||||
Example 1:
|
||||
|
||||
def cond_fn(iter, x):
|
||||
return iter.sum() < 10
|
||||
|
|
@ -92,9 +94,19 @@ def while_loop(cond_fn, body_fn, carried_inputs):
|
|||
|
||||
while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
|
||||
|
||||
Example 2:
|
||||
|
||||
def cond_fn(int_iter, x):
|
||||
return 2 * int_iter < x.shape[0]
|
||||
|
||||
def body_fn(int_iter, x):
|
||||
return int_iter + 1, x + int_iter
|
||||
|
||||
while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))
|
||||
|
||||
Restrictions:
|
||||
|
||||
- body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.
|
||||
- body_fn must return tensors or ints with the same metadata (e.g. shape, dtype) as inputs.
|
||||
|
||||
- body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
|
||||
|
||||
|
|
@ -171,12 +183,17 @@ def while_loop(cond_fn, body_fn, carried_inputs):
|
|||
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
|
||||
carried_vals = carried_inputs
|
||||
|
||||
def _is_boolean_scalar_tensor(pred):
|
||||
return (
|
||||
def _validate_cond_output(pred):
|
||||
if (
|
||||
isinstance(pred, torch.Tensor)
|
||||
and pred.size() == torch.Size([])
|
||||
and pred.dtype == torch.bool
|
||||
)
|
||||
) or isinstance(pred, bool):
|
||||
return
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"cond_fn must return a boolean scalar tensor or a boolean but got {pred}"
|
||||
)
|
||||
|
||||
if not isinstance(carried_inputs, tuple):
|
||||
raise RuntimeError(
|
||||
|
|
@ -184,10 +201,7 @@ def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
|
|||
)
|
||||
|
||||
while pred := cond_fn(*carried_vals, *additional_inputs):
|
||||
if not _is_boolean_scalar_tensor(pred):
|
||||
raise RuntimeError(
|
||||
f"cond_fn must return a boolean scalar tensor but got {pred}"
|
||||
)
|
||||
_validate_cond_output(pred)
|
||||
out = body_fn(*carried_vals, *additional_inputs)
|
||||
assert isinstance(
|
||||
out, tuple
|
||||
|
|
@ -204,13 +218,79 @@ while_loop_op.py_impl(DispatchKey.Autograd)(
|
|||
)
|
||||
|
||||
|
||||
def _find_or_create_fake_mode() -> FakeTensorMode:
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
fake_mode = torch._guards.detect_fake_mode()
|
||||
if fake_mode is None:
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
|
||||
return fake_mode
|
||||
|
||||
|
||||
def _create_unbacked_symint(fake_mode: FakeTensorMode) -> torch.SymInt:
|
||||
assert (
|
||||
fake_mode is not None and fake_mode.shape_env is not None
|
||||
), "Must provide a fake_mode with shape_env."
|
||||
with fake_mode.shape_env.ignore_fresh_unbacked_symbols():
|
||||
return fake_mode.shape_env.create_unbacked_symint()
|
||||
|
||||
|
||||
@while_loop_op.py_impl(ProxyTorchDispatchMode)
|
||||
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
|
||||
def _trace_while_loop(
|
||||
proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
|
||||
):
|
||||
cond_graph = reenter_make_fx(cond_fn)(*carried_inputs, *additional_inputs)
|
||||
body_graph = reenter_make_fx(body_fn)(*carried_inputs, *additional_inputs)
|
||||
# NOTE [unspecialize int carry with unbacked symints]
|
||||
# When we support int carry, we'll also need to support int output of body_fn because
|
||||
# previous iteration's output is next iteration's input and they must match.
|
||||
# For carries, when we start tracing while_loop, they can be
|
||||
# - constants e.g. (0, [1, 3])
|
||||
# - backed symints (x.shape[0], [x.shape[1] + x.stride[1], x.shape[2]])
|
||||
# - unbacked symints e.g. (u0, [u0 + u1, u2])
|
||||
# We choose the most conservative design: in all cases, we create new unbacked symints to trace the
|
||||
# subgraph. It's possible to do some analysis on initial carry and the output of first
|
||||
# iteration to determine a better range for the output unbacked symbol e.g. when input is an unbacked
|
||||
# symint >= 0 before the while_loop but in general this is difficult because we don't know
|
||||
# the number of iterations. Users would have to re-constrain the unbacked symint in subgraph if needed.
|
||||
#
|
||||
# For output of fake cond_fn, it could be constant bool or SymBool (e.g. return x.shape[0] < 4,
|
||||
# where x.shape[0] can be either static of dynamic). In the case of constant bool, we should do a
|
||||
# specialization (NYI).
|
||||
|
||||
# For output of fake body_fn, it could be all three types though from user's point of view,
|
||||
# they're all integers e.g.
|
||||
|
||||
# init_carry = (0, s0, u1, t)
|
||||
# def body_fn(u0, s0, u1, t):
|
||||
# ...
|
||||
# return (t.shape[0], t.shape[1], t.shape[2], y + 1)
|
||||
#
|
||||
# It may seem that a constant output isn't possible: users shouldn't write a while_loop
|
||||
# that always return 0. But it could be that a shape is not set as dynamic properly (e.g.
|
||||
# automatic dynamic hasn't been triggered).
|
||||
#
|
||||
# For this reason, we treat int, symint outputs in the same way:
|
||||
# - they can match against any of int, symint carry
|
||||
# - we unspecialize them with new unbacked symints in fake while_loop
|
||||
# Similarly, we could do some analysis to refine the output ranges but it's easier to start with
|
||||
# fresh unbacked symints. One surprising case can be: an input unbacked symint is constrained by
|
||||
# users to be >= 0 (either before while_loop or inside body_fn) and it increments by 1 in each
|
||||
# iteration. Ideally, we should know that the final output is >= 0 but we didn't constrain the
|
||||
# unbacked symint output of subgraph as of today because this requires a smart range analysis.
|
||||
fake_mode: FakeTensorMode = _find_or_create_fake_mode()
|
||||
unspecialized_carried_inputs = pytree.tree_map_only(
|
||||
(int, torch.SymInt),
|
||||
lambda _: _create_unbacked_symint(fake_mode),
|
||||
carried_inputs,
|
||||
)
|
||||
|
||||
cond_graph = reenter_make_fx(cond_fn)(
|
||||
*unspecialized_carried_inputs, *additional_inputs
|
||||
)
|
||||
body_graph = reenter_make_fx(body_fn)(
|
||||
*unspecialized_carried_inputs, *additional_inputs
|
||||
)
|
||||
|
||||
next_name = None
|
||||
i = 0
|
||||
|
|
@ -235,7 +315,9 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
|
|||
"call_function", while_loop_op, proxy_args, {}, name="while_loop"
|
||||
)
|
||||
|
||||
out = while_loop_op(cond_graph, body_graph, carried_inputs, additional_inputs)
|
||||
out = while_loop_op(
|
||||
cond_graph, body_graph, unspecialized_carried_inputs, additional_inputs
|
||||
)
|
||||
return track_tensor_tree(
|
||||
out, out_proxy, constant=None, tracer=proxy_mode.tracer
|
||||
)
|
||||
|
|
@ -246,16 +328,57 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
|
|||
|
||||
|
||||
def check_outputs_carry_consistency(
|
||||
outs: List[torch.Tensor], carries: List[torch.Tensor]
|
||||
outs: List[Union[torch.Tensor, torch.SymInt, int]],
|
||||
carries: List[Union[torch.Tensor, torch.SymInt, int]],
|
||||
) -> None:
|
||||
all_diffs_in_meta = []
|
||||
for out, cry in zip(outs, carries):
|
||||
if diff := diff_tensor_meta(
|
||||
_extract_tensor_metadata(cry), _extract_tensor_metadata(out)
|
||||
):
|
||||
all_diffs_in_meta.append(",".join(diff))
|
||||
if all_diffs_in_meta:
|
||||
diff_str = "\n".join(all_diffs_in_meta)
|
||||
def diff_meta_pairs(
|
||||
lhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
|
||||
rhs_list: List[Union[torch.Tensor, torch.SymInt, int]],
|
||||
) -> List[str]:
|
||||
def diff_meta(
|
||||
lhs: Union[torch.Tensor, torch.SymInt, int],
|
||||
rhs: Union[torch.Tensor, torch.SymInt, int],
|
||||
) -> str:
|
||||
if isinstance(lhs, torch.Tensor) and isinstance(rhs, torch.Tensor):
|
||||
return ", ".join(
|
||||
diff_tensor_meta(
|
||||
# We set include contiguity=False because we have vmap x cond tests, where if
|
||||
# include_contiguity=True will call t.is_contiguous inside of vmap and get an error
|
||||
# "querying is_contiguous inside of vmap for memory_format other than
|
||||
# torch.contiguous_format is not yet implemented". This is acceptable because stride
|
||||
# is still checked.
|
||||
_extract_tensor_metadata(lhs, include_contiguity=False),
|
||||
_extract_tensor_metadata(rhs, include_contiguity=False),
|
||||
check_grad=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
def _both_int_types(lhs, rhs):
|
||||
return isinstance(lhs, (int, torch.SymInt)) and isinstance(
|
||||
rhs, (int, torch.SymInt)
|
||||
)
|
||||
|
||||
def _both_tensor(lhs, rhs):
|
||||
return isinstance(lhs, torch.Tensor) and isinstance(
|
||||
rhs, torch.Tensor
|
||||
)
|
||||
|
||||
if not _both_int_types(lhs, rhs) and not _both_tensor(lhs, rhs):
|
||||
return f"type: {lhs} vs {rhs}"
|
||||
|
||||
return ""
|
||||
|
||||
all_diffs = []
|
||||
for i, (lhs, rhs) in enumerate(zip(lhs_list, rhs_list)):
|
||||
if diff := diff_meta(lhs, rhs):
|
||||
all_diffs.append(
|
||||
f"pair[{i}] differ in {diff}, where lhs is {lhs} and rhs is {rhs}"
|
||||
)
|
||||
return all_diffs
|
||||
|
||||
if all_diffs := diff_meta_pairs(outs, carries):
|
||||
diff_str = "\n".join(all_diffs)
|
||||
raise RuntimeError(
|
||||
f"Expected carried_inputs and body outputs return tensors with same metadata but found:\n{diff_str}"
|
||||
)
|
||||
|
|
@ -266,12 +389,12 @@ def while_loop_fake_tensor_mode(
|
|||
mode, cond_fn, body_fn, carried_inputs, additional_inputs
|
||||
):
|
||||
with mode:
|
||||
# NOTE: [Handling unback symints created in subgraph of while_loop]
|
||||
# NOTE: [Handling unbacked symints in subgraph of while_loop]
|
||||
# The idea is that the scope of unbacked symints are limited to the subgraph.
|
||||
#
|
||||
# We're implementing the fake tensor mode of while_loop operator.
|
||||
# and we run body_fn once to get an fake output.
|
||||
# Let's only consider tensor output for now:
|
||||
# Let's first consider the case that unbacked symints are tensor shapes:
|
||||
#
|
||||
# Case 1:
|
||||
# if the unbacked symints is local to the subgraph e.g.
|
||||
|
|
@ -282,8 +405,8 @@ def while_loop_fake_tensor_mode(
|
|||
# no effect on the output of while_loop and it's tracked when we tracing.
|
||||
# the subgraph.
|
||||
#
|
||||
# Case 2.1:
|
||||
# if the unbacked symints are part of output of while_loop e.g.
|
||||
# Case 2:
|
||||
# if the unbacked symints are shape of output of while_loop e.g.
|
||||
# def body_fn(it, x):
|
||||
# nz = x.nonzero()
|
||||
# return it+1, nz
|
||||
|
|
@ -291,8 +414,8 @@ def while_loop_fake_tensor_mode(
|
|||
# must match the output shape as nz.shape contains newly allocated unbacked symint, this
|
||||
# won't match the carried_input's shape.
|
||||
#
|
||||
# Case 2.2:
|
||||
# if the unbacked symints are part of carried_inputs e.g.
|
||||
# Case 3:
|
||||
# if the unbacked symints are shape of carried_inputs e.g.
|
||||
# nz = a.nonzero()
|
||||
# body_fn(it, nz):
|
||||
# return it+1. nz.sin() + 1,
|
||||
|
|
@ -302,7 +425,10 @@ def while_loop_fake_tensor_mode(
|
|||
# so we could just return the output after one iteration.
|
||||
body_outs = body_fn(*carried_inputs, *additional_inputs)
|
||||
check_outputs_carry_consistency(body_outs, carried_inputs)
|
||||
return body_outs
|
||||
# See NOTE [unspecialize int carry with unbacked symints]
|
||||
return pytree.tree_map_only(
|
||||
(int, torch.SymInt), lambda _: _create_unbacked_symint(mode), body_outs
|
||||
)
|
||||
|
||||
|
||||
@while_loop_op.py_functionalize_impl
|
||||
|
|
|
|||
Loading…
Reference in a new issue