[NestedTensor] Use maybe_mark_dynamic instead of mark_dynamic (#127453)

Fixes #127097

**TL;DR**: dimensions marked with `mark_dynamic` can cause assertion failures if the marked-dynamic dimensions later get specialized. In NJT, we don't care _that_ much whether the dimension actually stays dynamic, so instead we mark it with `maybe_mark_dynamic`, which suggests that the dimension should be dynamic but doesn't fail if it gets specialized.
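For illustration, here is a minimal sketch (not code from this PR; the branching function and shapes are made up) of how the two markers differ once dynamo specializes the marked dimension:

```python
import torch

def fn(t):
    # Branching on the concrete size forces dynamo to specialize dim 0.
    if t.shape[0] == 8:
        return t + 1
    return t - 1

x = torch.randn(8, 16)
torch._dynamo.mark_dynamic(x, 0)
try:
    # mark_dynamic treats specialization of the marked dim as a hard error.
    torch.compile(fn)(x)
except Exception as e:
    print("mark_dynamic:", type(e).__name__)

y = torch.randn(8, 16)
torch._dynamo.maybe_mark_dynamic(y, 0)
# maybe_mark_dynamic starts the dim out symbolic but tolerates specialization,
# so this compiles and runs.
print("maybe_mark_dynamic:", torch.compile(fn)(y).shape)
```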

**Background**:
NJT marks the values tensor as dynamic:

49ad90349d/torch/nested/_internal/nested_tensor.py (L122)

It does this for two reasons:
1. **Conceptual**: We know that this dimension _should_ be dynamic; it's a nested tensor, so the sequence lengths will _probably_ vary between batches in the common case. Therefore, we should compile it as dynamic to avoid needing a recompile to trigger automatic dynamic shapes (illustrated in the sketch after this list).
2. **Implementation detail**: Right now we run into issues with torch.compile / tensor_unflatten / other details when the dimensions are not marked as dynamic. There have been attempts to remove this requirement (e.g. https://github.com/pytorch/pytorch/pull/126563), but while testing that I wasn't able to get all tests to pass, so there could be regressions if we removed the mark_dynamic entirely.
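
As a rough illustration of point 1, here is a minimal sketch (not code from this PR; it assumes automatic dynamic shapes is enabled, which is the default) showing that an unmarked dimension only becomes dynamic after dynamo has seen two distinct sizes, i.e. after one extra recompile:

```python
import torch
import torch._dynamo.testing

counter = torch._dynamo.testing.CompileCounter()

@torch._dynamo.optimize(counter)
def fn(x):
    return x * 2

fn(torch.randn(8, 16))   # first call: compiled with static shapes
fn(torch.randn(9, 16))   # new size on dim 0: recompile, dim becomes dynamic
fn(torch.randn(10, 16))  # already dynamic: no further recompile
print(counter.frame_count)  # expected: 2
```

Marking the values tensor's ragged dimension up front avoids that extra recompile.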

**Justification for this change**

1. **Conceptual**: AFAIK, we don't care enough about the dynamism of this dimension to error out if it gets specialized. We'd prefer not to need a recompile to get automatic dynamic shapes, but avoiding these assertion failures matters more (as does not forcing the user to hunt down every other equivalent shape and mark it as dynamic too). This solution lets us suggest the dynamism without forcing it.
2. **Implementation detail**: This still marks the dimension as symbolic at the beginning of dynamo tracing, so we will (probably) avoid most of the issues we'd run into if we completely removed the `mark_dynamic` calls.

Differential Revision: [D57933779](https://our.internmc.facebook.com/intern/diff/D57933779)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127453
Approved by: https://github.com/soulitzer, https://github.com/YuqingJ
Author: David Berard
Date: 2024-05-31 00:07:06 +00:00
Committed by: PyTorch MergeBot
parent 6bfc6e0875
commit f33beb767d
2 changed files with 68 additions and 2 deletions


@@ -10,6 +10,8 @@ import math
import numpy as np
import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import (
@@ -4008,6 +4010,70 @@ class TestNestedTensorSubclass(TestCase):
        nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
        check_size(nt1_t, nt2_t, nt3_t, nt4_t)

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_specialize_dynamic_shape(self, device):
        values = torch.randn((18, 16), device=device)
        offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
        like_values = torch.randn_like(values)

        # this marks values as dynamic
        nt = torch.nested.nested_tensor_from_jagged(values, offsets)

        def fn(values, same_size):
            # here, the dynamic shape is specialized by same_size's shape
            # https://github.com/pytorch/pytorch/issues/127097
            # make sure this doesn't error out in torch.compile
            return values + same_size

        self.assertEqual(
            fn(values, like_values),
            torch.compile(fn)(values, like_values),
        )

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    def test_specialize_dynamic_shape_recompile(self, device):
        def generate_inp(total_len):
            values = torch.randn((total_len, 16), device=device)
            offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
            like_values = torch.randn_like(values)
            return values, offsets, like_values

        def check_results(ref_fn, res_fn, args):
            values, offsets, like_values = args
            # this may add dynamic shape markings
            # goal of this test is to make sure that whatever markings are there,
            # we eventually stop recompiling as shape changes.
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)

            self.assertEqual(
                ref_fn(values, like_values),
                res_fn(values, like_values),
            )

        def fn(values, same_size):
            return values + same_size

        compile_counter = torch._dynamo.testing.CompileCounter()

        compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
        check_results(fn, compiled_fn, generate_inp(18))
        self.assertEqual(compile_counter.frame_count, 1)

        check_results(fn, compiled_fn, generate_inp(19))
        # we'll probably recompile here with dynamic shapes - it's okay if not though.
        frame_count_2 = compile_counter.frame_count
        self.assertIn(frame_count_2, [1, 2])

        # make sure that by now we've already compiled with dynamic shapes, so additional
        # shapes should not trigger additional recompiles.
        check_results(fn, compiled_fn, generate_inp(20))
        self.assertEqual(compile_counter.frame_count, frame_count_2)

    # Doesn't work until we have real views
    @xfailIfTorchDynamo
    # Note 1: Math fallback doesn't work with bfloat16 on CUDA


@@ -118,8 +118,8 @@ class NestedTensor(torch.Tensor):
         self._metadata_cache = kwargs.get("_metadata_cache") or {}

         # collapsed ragged dim must always be dynamic
-        torch._dynamo.mark_dynamic(self, self._ragged_idx)
-        torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
+        torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
+        torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)

     def values(self):
         # dispatch to get proper view relationship