Handle unbacked SymInt sized outputs in AOTAutograd (#113159)

Thanks to aakhundov for constructing the test case. This PR was put together by running the failing test case and then fixing problems until we got all the way to the end. There are a few distinct fixes:

* AOTAutograd performs equality tests on tensor metadata to determine whether a metadata mutation has occurred. If we compare i0 with i1, we should report them as NOT equal, since we have evidently resized the tensor from i0 to i1 (even if, on a particular run, it is possible that i0 == i1). See the sketch after this list.
* There's a sketchy fix for `test_aot_autograd_exhaustive_matmul_cpu_float32`, where we check whether the output shape equals the tangent shape. Unfortunately, the same `definitely_true` treatment does not work here; it still fails on the example. I piled an extra sketchy fix on top of it, where I just try my best to avoid doing the view. Maybe we should have some sort of logging here.
* The partitioner needs a size for each unbacked SymInt when partitioning. I just feed it a heuristic fallback value in this case, similar to how we've been dealing with this in Inductor.
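As a quick illustration of the first fix, here is a minimal sketch (not code from this diff; it assumes `ShapeEnv.create_unbacked_symint` to mint unbacked symbols) of how `sym_eq` plus `definitely_true` report two distinct unbacked sizes as not-equal without guarding:

```python
# Sketch only. Uses the helpers this PR adds/exports from
# torch.fx.experimental.symbolic_shapes.
from torch.fx.experimental.symbolic_shapes import ShapeEnv, definitely_true, sym_eq

shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()  # e.g., a size derived from .tolist()
i1 = shape_env.create_unbacked_symint()

# sym_eq compares the tuples element-wise without installing guards.
same = sym_eq((i0, 4), (i1, 4))

# definitely_true is False: i0 == i1 cannot be proven at trace time, so the
# metadata is reported as changed (even though i0 could equal i1 on some run).
print(definitely_true(same))  # False
```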

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113159
Approved by: https://github.com/aakhundov, https://github.com/bdhirsh
Author: Edward Z. Yang
Date:   2023-11-07 17:16:31 -08:00
Committed by: PyTorch MergeBot
Parent: aa376e31fd
Commit: 1f3fa13f0a

9 changed files with 69 additions and 16 deletions


@@ -979,6 +979,7 @@ coverage_ignore_functions = [
     "is_symbolic",
     "parallel_and",
     "parallel_or",
+    "sym_eq",
     "tensor_has_hints",
     # torch.fx.experimental.unification.core
     "reify",


@@ -3457,6 +3457,20 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(ref_mantissa, mantissa)
         self.assertEqual(ref_exponent, exponent)
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_split_with_sizes_aot_autograd(self):
+        def fn(result, split_sizes):
+            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
+            return rs
+
+        example_inputs = (
+            torch.randn(32, requires_grad=True),
+            torch.tensor((7, 16, 9)),
+        )
+        actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
+        expected = fn(*example_inputs)
+        self.assertEqual(actual, expected)
+
     def test_unspecialized_nn_module_with_torch_variable_attribute(self):
         """
         In this case self.fn = something that should be a TorchVariable.


@@ -3503,7 +3503,6 @@ symbolic_aot_autograd_failures = {
     xfail('nn.functional.nll_loss', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('nn.functional.pixel_shuffle', ''),  # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
     xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta...
-    xfail('repeat_interleave', ''),  # aten.repeat_interleave.Te...
     xfail('_segment_reduce', 'lengths'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('_segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('special.i1', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition


@@ -592,7 +592,6 @@ class TestPythonDispatch(TestCase):
 $0: f32[1] = input('x')
 $1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
 $2: f32[1] = input('grad_y')
-True = torch._ops.aten.is_same_size.default($1, $2)
 $3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
 $4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
 $5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''')
@@ -852,7 +851,6 @@ $0: f32[1] = input('x')
 $1: f32[1] = input('x.grad')
 $2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
 $3: f32[1] = input('grad_output')
-True = torch._ops.aten.is_same_size.default($2, $3)
 $4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
 $5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
 $6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')


@@ -31,7 +31,9 @@ from torch._subclasses.fake_tensor import is_fake
 from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
 from torch.fx import immutable_collections, Interpreter
 from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
-from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
+from torch.fx.experimental.symbolic_shapes import (
+    ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
+)
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.nn.utils import stateless
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass
@@ -786,10 +788,10 @@ class TensorAlias:
 def has_same_metadata(t1, t2):
     return (
-        t1.size() == t2.size()
-        and t1.stride() == t2.stride()
-        and t1.storage_offset() == t2.storage_offset()
+        definitely_true(sym_eq(t1.size(), t2.size()))
+        and definitely_true(sym_eq(t1.stride(), t2.stride()))
+        and definitely_true(t1.storage_offset() == t2.storage_offset())
         and t1.is_conj() == t2.is_conj()
         and t1.is_neg() == t2.is_neg()
     )
@@ -1769,8 +1771,11 @@ def create_joint(
             # A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
             # The issue is that we are sensitive to decomps that don't accurately maintain
             # their output's _base.shape compared to eager mode, and this helps mitigate a bit.
+            # The not definitely_false is also sketchy; if unbacked
+            # symints are involved, we're just going to assume that the
+            # decomps setup the base shape correctly
             needed_outs.append(
-                out if out.shape == tangent.shape else out.view(tangent.shape)
+                out if not definitely_false(sym_eq(out.shape, tangent.shape)) else out.view(tangent.shape)
             )
             needed_tangents.append(tangent)
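For context on the `not definitely_false` trick in the hunk above, a minimal sketch (same `create_unbacked_symint` assumption as before) of how the two predicates behave on an undecidable comparison:

```python
# Sketch: on an unbacked SymInt, neither predicate can decide the comparison,
# so `not definitely_false(...)` conservatively keeps `out` and skips the view.
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv, definitely_true, definitely_false
)

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

cond = u0 == 5                 # a SymBool with no hint
print(definitely_true(cond))   # False: equality cannot be proven
print(definitely_false(cond))  # False: inequality cannot be proven either
```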


@@ -314,10 +314,13 @@ def _size_of(node: fx.Node) -> int:
             return 1
         else:
             return 999999
+    # NB: The fallback values here are meaningless, maybe we should respect
+    # torch._inductor.config.unbacked_symint_fallback (but this is a
+    # layering violation)
     elif isinstance(val, (list, tuple)):
-        return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
+        return sum(_tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) for n in val if isinstance(n, torch.Tensor))
     elif isinstance(val, torch.Tensor):
-        return _tensor_nbytes(hint_int(val.numel()), val.dtype)
+        return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype)
     raise RuntimeError(f"Unknown metadata type {type(val)}")
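A usage sketch of the fallback path this hunk relies on (the 4098 heuristic is the value from the diff; the unbacked symbol is minted as in the earlier sketches):

```python
# Sketch: hint_int now degrades to a caller-supplied heuristic for unbacked
# SymInts instead of raising.
from torch.fx.experimental.symbolic_shapes import ShapeEnv, hint_int

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

print(hint_int(7))                  # 7: plain ints pass through unchanged
print(hint_int(u0, fallback=4098))  # 4098: no hint, so the fallback is used
# hint_int(u0) with no fallback would raise, since there is no real hint.
```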


@@ -64,8 +64,20 @@ def _make_grads(
     new_grads: List[_OptionalTensor] = []
     for out, grad in zip(outputs, grads):
         if isinstance(grad, torch.Tensor):
+            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
+
             first_grad = grad if not is_grads_batched else grad[0]
-            if not torch.is_same_size(out, first_grad):
+            # TODO: We can remove this conditional once we uniformly use
+            # singleton int to represent jagged dimension, so that size() call
+            # on nested tensor works
+            if out.is_nested or first_grad.is_nested:
+                shape_matches = torch.is_same_size(out, first_grad)
+            else:
+                # We need to do a regular size check, without going through
+                # the operator, to be able to handle unbacked symints
+                # (expect_true ensures we can deal with unbacked)
+                shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
+            if not shape_matches:
                 out_shape, grad_shape = _calculate_shape(
                     out, first_grad, is_grads_batched
                 )
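`expect_true` differs from `definitely_true`: when the comparison cannot be decided at trace time, it defers a runtime assert and optimistically returns True rather than reporting a mismatch. A sketch under the same assumptions as above:

```python
# Sketch: expect_true on an undecidable SymBool records a deferred runtime
# assert (checked once real sizes are known) and returns True at trace time.
from torch.fx.experimental.symbolic_shapes import ShapeEnv, expect_true, sym_eq

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

print(expect_true(sym_eq((u0, 4), (u0, 4))))  # True: provable outright
print(expect_true(u0 == 3))  # True, with "u0 == 3" asserted at runtime
```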


@@ -132,10 +132,12 @@ class SymNode:
             self._update_hint()
         return self._hint is not None
 
-    def require_hint(self):
+    def require_hint(self, fallback=None):
         if self._hint is None:
             self._update_hint()
             if self._hint is None:
+                if fallback is not None:
+                    return fallback
                 # NB: we expect this to raise
                 return self.shape_env.size_hint(self.expr)
         return self._hint
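At the SymNode level, the same graceful degradation can be exercised directly (sketch, same unbacked-symbol assumption):

```python
# Sketch: require_hint with a fallback no longer raises on unbacked symbols.
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()

print(u0.node.require_hint(fallback=4098))  # 4098 instead of an error
```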


@@ -61,7 +61,7 @@ __all__ = [
     "guard_int", "guard_float", "guard_scalar",
     "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
     "is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
-    "has_free_symbols",
+    "has_free_symbols", "sym_eq",
 ]
 
 # FX node metadata keys for symbolic shape FX graph.
@@ -101,9 +101,14 @@ def create_contiguous(shape):
         strides.append(dim * strides[-1])
     return list(reversed(strides))
 
 
-def hint_int(a):
+def hint_int(a, fallback=None):
+    """
+    Retrieve the hint for an int (based on the underlying real values as observed
+    at runtime). If no hint is available (e.g., because data dependent shapes),
+    if fallback is not None, use that instead (otherwise raise an error).
+    """
     if isinstance(a, torch.SymInt):
-        return a.node.require_hint()
+        return a.node.require_hint(fallback)
     assert type(a) is int, a
     return a
@@ -275,6 +280,20 @@ def parallel_and(*args):
             return False
     return all(args)
 
+def sym_eq(x, y):
+    """
+    Like ==, but when run on list/tuple, it will recursively test equality
+    and use sym_and to join the results together, without guarding.
+    """
+    if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)):
+        if len(x) != len(y):
+            return False
+        return functools.reduce(operator.and_, map(sym_eq, x, y), True)
+    elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
+        return x == y
+    else:
+        raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
+
 def guard_scalar(a):
     if isinstance(a, (SymBool, bool)):
         return guard_bool(a)
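Finally, a usage sketch of `sym_eq` on concrete inputs, showing the element-wise comparison and the length-mismatch short-circuit:

```python
# Sketch: on plain ints, sym_eq behaves like a guard-free element-wise ==.
from torch.fx.experimental.symbolic_shapes import sym_eq

print(sym_eq((7, 16, 9), (7, 16, 9)))  # True
print(sym_eq((7, 16), (7, 16, 9)))     # False: length mismatch
print(sym_eq([1, 2], [1, 3]))          # False: 2 != 3
```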