mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Handle unbacked SymInt sized outputs in AOTAutograd (#113159)
Thanks aakhundov for constructing the test case. This PR was constructed by running the failing test case, and then fixing problems until we got all the way to the end. There are a few distinct fixes: * AOTAutograd performs equality tests on tensor metadata to determine if a metadata mutation had occurred. If we test i0 vs i1, we should report these are NOT equal, since obviously we have somehow resized the tensor from i0 to i1 (even if, on a particular run, it is possible i0 == i1). * There's a sketchy fix for `test_aot_autograd_exhaustive_matmul_cpu_float32` where we check if the output shape equals the tangent shape. Unfortunately, the same `definitely_true` treatment does not work here, it still fails on the example. I piled an extra sketchy fix on top of it, where I just try my best to avoid doing the view. Maybe we should have some sort of logging here. * Partitioner needs to get out a size for unbacked SymInt when partitioning. I just feed it a random heuristic value in this case, similar to how we've been dealing with this in Inductor. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113159 Approved by: https://github.com/aakhundov, https://github.com/bdhirsh
This commit is contained in:
parent
aa376e31fd
commit
1f3fa13f0a
9 changed files with 69 additions and 16 deletions
|
|
@ -979,6 +979,7 @@ coverage_ignore_functions = [
|
|||
"is_symbolic",
|
||||
"parallel_and",
|
||||
"parallel_or",
|
||||
"sym_eq",
|
||||
"tensor_has_hints",
|
||||
# torch.fx.experimental.unification.core
|
||||
"reify",
|
||||
|
|
|
|||
|
|
@ -3457,6 +3457,20 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(ref_mantissa, mantissa)
|
||||
self.assertEqual(ref_exponent, exponent)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_split_with_sizes_aot_autograd(self):
|
||||
def fn(result, split_sizes):
|
||||
rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
|
||||
return rs
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(32, requires_grad=True),
|
||||
torch.tensor((7, 16, 9)),
|
||||
)
|
||||
actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
|
||||
expected = fn(*example_inputs)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_unspecialized_nn_module_with_torch_variable_attribute(self):
|
||||
"""
|
||||
In this case self.fn = something that should be a TorchVariable.
|
||||
|
|
|
|||
|
|
@ -3503,7 +3503,6 @@ symbolic_aot_autograd_failures = {
|
|||
xfail('nn.functional.nll_loss', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
|
||||
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
|
||||
xfail('repeat_interleave', ''), # aten.repeat_interleave.Te...
|
||||
xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
|
||||
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
|
||||
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
|
||||
|
|
|
|||
|
|
@ -592,7 +592,6 @@ class TestPythonDispatch(TestCase):
|
|||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
|
||||
$2: f32[1] = input('grad_y')
|
||||
True = torch._ops.aten.is_same_size.default($1, $2)
|
||||
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''')
|
||||
|
|
@ -852,7 +851,6 @@ $0: f32[1] = input('x')
|
|||
$1: f32[1] = input('x.grad')
|
||||
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
|
||||
$3: f32[1] = input('grad_output')
|
||||
True = torch._ops.aten.is_same_size.default($2, $3)
|
||||
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
|
||||
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
|
||||
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
|
|
|
|||
|
|
@ -31,7 +31,9 @@ from torch._subclasses.fake_tensor import is_fake
|
|||
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
|
||||
from torch.fx import immutable_collections, Interpreter
|
||||
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ShapeEnv, is_concrete_int, fx_placeholder_vals, definitely_true, definitely_false, sym_eq
|
||||
)
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.nn.utils import stateless
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass
|
||||
|
|
@ -786,10 +788,10 @@ class TensorAlias:
|
|||
|
||||
def has_same_metadata(t1, t2):
|
||||
return (
|
||||
t1.size() == t2.size()
|
||||
and t1.stride() == t2.stride()
|
||||
and t1.storage_offset() == t2.storage_offset()
|
||||
and t1.storage_offset() == t2.storage_offset()
|
||||
definitely_true(sym_eq(t1.size(), t2.size()))
|
||||
and definitely_true(sym_eq(t1.stride(), t2.stride()))
|
||||
and definitely_true(t1.storage_offset() == t2.storage_offset())
|
||||
and definitely_true(t1.storage_offset() == t2.storage_offset())
|
||||
and t1.is_conj() == t2.is_conj()
|
||||
and t1.is_neg() == t2.is_neg()
|
||||
)
|
||||
|
|
@ -1769,8 +1771,11 @@ def create_joint(
|
|||
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
|
||||
# The issue is that we are sensitive to decomps that don't accurately maintain
|
||||
# their output's _base.shape compared to eager mode, and this helps mitigate a bit.
|
||||
# The not definitely_false is also sketchy; if unbacked
|
||||
# symints are involved, we're just going to assume that the
|
||||
# decomps setup the base shape correctly
|
||||
needed_outs.append(
|
||||
out if out.shape == tangent.shape else out.view(tangent.shape)
|
||||
out if not definitely_false(sym_eq(out.shape, tangent.shape)) else out.view(tangent.shape)
|
||||
)
|
||||
needed_tangents.append(tangent)
|
||||
|
||||
|
|
|
|||
|
|
@ -314,10 +314,13 @@ def _size_of(node: fx.Node) -> int:
|
|||
return 1
|
||||
else:
|
||||
return 999999
|
||||
# NB: The fallback values here are meaningless, maybe we should respect
|
||||
# torch._inductor.config.unbacked_symint_fallback (but this is a
|
||||
# layering violation)
|
||||
elif isinstance(val, (list, tuple)):
|
||||
return sum(_tensor_nbytes(hint_int(n.numel()), n.dtype) for n in val if isinstance(n, torch.Tensor))
|
||||
return sum(_tensor_nbytes(hint_int(n.numel(), fallback=4098), n.dtype) for n in val if isinstance(n, torch.Tensor))
|
||||
elif isinstance(val, torch.Tensor):
|
||||
return _tensor_nbytes(hint_int(val.numel()), val.dtype)
|
||||
return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype)
|
||||
|
||||
raise RuntimeError(f"Unknown metadata type {type(val)}")
|
||||
|
||||
|
|
|
|||
|
|
@ -64,8 +64,20 @@ def _make_grads(
|
|||
new_grads: List[_OptionalTensor] = []
|
||||
for out, grad in zip(outputs, grads):
|
||||
if isinstance(grad, torch.Tensor):
|
||||
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
|
||||
|
||||
first_grad = grad if not is_grads_batched else grad[0]
|
||||
if not torch.is_same_size(out, first_grad):
|
||||
# TODO: We can remove this conditional once we uniformly use
|
||||
# singleton int to represent jagged dimension, so that size() call
|
||||
# on nested tensor works
|
||||
if out.is_nested or first_grad.is_nested:
|
||||
shape_matches = torch.is_same_size(out, first_grad)
|
||||
else:
|
||||
# We need to do a regular size check, without going through
|
||||
# the operator, to be able to handle unbacked symints
|
||||
# (expect_true ensures we can deal with unbacked)
|
||||
shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
|
||||
if not shape_matches:
|
||||
out_shape, grad_shape = _calculate_shape(
|
||||
out, first_grad, is_grads_batched
|
||||
)
|
||||
|
|
|
|||
|
|
@ -132,10 +132,12 @@ class SymNode:
|
|||
self._update_hint()
|
||||
return self._hint is not None
|
||||
|
||||
def require_hint(self):
|
||||
def require_hint(self, fallback=None):
|
||||
if self._hint is None:
|
||||
self._update_hint()
|
||||
if self._hint is None:
|
||||
if fallback is not None:
|
||||
return fallback
|
||||
# NB: we expect this to raise
|
||||
return self.shape_env.size_hint(self.expr)
|
||||
return self._hint
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ __all__ = [
|
|||
"guard_int", "guard_float", "guard_scalar",
|
||||
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
||||
"is_concrete_bool", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||||
"has_free_symbols",
|
||||
"has_free_symbols", "sym_eq",
|
||||
]
|
||||
|
||||
# FX node metadata keys for symbolic shape FX graph.
|
||||
|
|
@ -101,9 +101,14 @@ def create_contiguous(shape):
|
|||
strides.append(dim * strides[-1])
|
||||
return list(reversed(strides))
|
||||
|
||||
def hint_int(a):
|
||||
def hint_int(a, fallback=None):
|
||||
"""
|
||||
Retrieve the hint for an int (based on the underlying real values as observed
|
||||
at runtime). If no hint is available (e.g., because data dependent shapes),
|
||||
if fallback is not None, use that instead (otherwise raise an error).
|
||||
"""
|
||||
if isinstance(a, torch.SymInt):
|
||||
return a.node.require_hint()
|
||||
return a.node.require_hint(fallback)
|
||||
assert type(a) is int, a
|
||||
return a
|
||||
|
||||
|
|
@ -275,6 +280,20 @@ def parallel_and(*args):
|
|||
return False
|
||||
return all(args)
|
||||
|
||||
def sym_eq(x, y):
|
||||
"""
|
||||
Like ==, but when run on list/tuple, it will recursively test equality
|
||||
and use sym_and to join the results together, without guarding.
|
||||
"""
|
||||
if (isinstance(x, tuple) and isinstance(y, tuple)) or (isinstance(x, list) and isinstance(y, list)):
|
||||
if len(x) != len(y):
|
||||
return False
|
||||
return functools.reduce(operator.and_, map(sym_eq, x, y), True)
|
||||
elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)):
|
||||
return x == y
|
||||
else:
|
||||
raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}")
|
||||
|
||||
def guard_scalar(a):
|
||||
if isinstance(a, (SymBool, bool)):
|
||||
return guard_bool(a)
|
||||
|
|
|
|||
Loading…
Reference in a new issue