From f33beb767d04ad00aecbcf16690e786eb93ebdd8 Mon Sep 17 00:00:00 2001 From: David Berard Date: Fri, 31 May 2024 00:07:06 +0000 Subject: [PATCH] [NestedTensor] Use maybe_mark_dynamic instead of mark_dynamic (#127453) Fixes #127097 **TL;DR**: dimensions marked with mark_dynamic can result in assertion failures if the marked-dynamic dimensions get specialized. In NJT, we don't care _that_ much that a dimension is marked as dynamic. So instead, mark with `maybe_mark_dynamic` which suggests that a dimension should be dynamic, but doesn't fail if the dimension gets specialized. **Background**: NJT marks the values tensor as dynamic: https://github.com/pytorch/pytorch/blob/49ad90349d57c35ab83f40c28d8b18caefb416d1/torch/nested/_internal/nested_tensor.py#L122 It does this for two reasons: 1. **Conceptual**: We know that this dimension _should_ be dynamic; it's a nested tensor, so the sequence lengths will _probably_ vary between batches in the common case. Therefore, we should compile it as dynamic to prevent needing a recompile to trigger automatic dynamic shapes. 2. **Implementation detail**: Right now we run into issues with torch.compile / tensor_unflatten / other details when the dimensions are not marked as dynamic. We have some attempts to remove this (e.g. https://github.com/pytorch/pytorch/pull/126563) but while testing this I wasn't able to get all tests to pass, so there could be potential regressions here if we removed the mark_dynamic. **Justification for this change** 1. **Conceptual**: AFAIK, we don't care enough about the dynamism of this dimension to error out if we specialize. We'd prefer that we don't have to recompile to get automatic dynamic shapes, but it's also better to not have this issue (and not to force the user to go hunt down all the other equivalent shapes to mark them as dynamic as well). This solution allows us to suggest the dynamism but not force it. 2. **Implementation detail**: This still marks the dimension as symbolic at the beginning of dynamo tracing, so we will (probably) avoid a lot of the issues we run into when we completely remove the `mark_dynamic` decorators. Differential Revision: [D57933779](https://our.internmc.facebook.com/intern/diff/D57933779) Pull Request resolved: https://github.com/pytorch/pytorch/pull/127453 Approved by: https://github.com/soulitzer, https://github.com/YuqingJ --- test/test_nestedtensor.py | 66 +++++++++++++++++++++++++ torch/nested/_internal/nested_tensor.py | 4 +- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 597180129f7..d369135a6e5 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -10,6 +10,8 @@ import math import numpy as np import torch +import torch._dynamo +import torch._dynamo.testing import torch.nn import torch.nn.functional as F from torch.testing._internal.common_cuda import ( @@ -4008,6 +4010,70 @@ class TestNestedTensorSubclass(TestCase): nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) check_size(nt1_t, nt2_t, nt3_t, nt4_t) + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape(self, device): + values = torch.randn((18, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device) + like_values = torch.randn_like(values) + + # this marks values as dynamic + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + def fn(values, same_size): + # here, the dynamic shape is specialized by same_size's shape + # https://github.com/pytorch/pytorch/issues/127097 + # make sure this doesn't error out in torch.compile + return values + same_size + + self.assertEqual( + fn(values, like_values), + torch.compile(fn)(values, like_values), + ) + + @skipIfTorchDynamo("compiles internally") + @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile") + @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70") + def test_specialize_dynamic_shape_recompile(self, device): + def generate_inp(total_len): + values = torch.randn((total_len, 16), device=device) + offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device) + like_values = torch.randn_like(values) + return values, offsets, like_values + + def check_results(ref_fn, res_fn, args): + values, offsets, like_values = args + # this may add dynamic shape markings + # goal of this test is to make sure that whatever markings are there, + # we eventually stop recompiling as shape changes. + nt = torch.nested.nested_tensor_from_jagged(values, offsets) + + self.assertEqual( + ref_fn(values, like_values), + res_fn(values, like_values), + ) + + + def fn(values, same_size): + return values + same_size + + compile_counter = torch._dynamo.testing.CompileCounter() + + compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn) + check_results(fn, compiled_fn, generate_inp(18)) + self.assertEqual(compile_counter.frame_count, 1) + + check_results(fn, compiled_fn, generate_inp(19)) + # we'll probably recompile here with dynamic shapes - it's okay if not though. + frame_count_2 = compile_counter.frame_count + self.assertIn(frame_count_2, [1, 2]) + + # make sure that by now we've already compiled with dynamic shapes, so additional + # shapes should not trigger additional recompiles. + check_results(fn, compiled_fn, generate_inp(20)) + self.assertEqual(compile_counter.frame_count, frame_count_2) + # Doesn't work until we have real views @xfailIfTorchDynamo # Note 1: Math fallback doesn't work with bfloat16 on CUDA diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index 5cc6b1c75d7..5ef8983a839 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -118,8 +118,8 @@ class NestedTensor(torch.Tensor): self._metadata_cache = kwargs.get("_metadata_cache") or {} # collapsed ragged dim must always be dynamic - torch._dynamo.mark_dynamic(self, self._ragged_idx) - torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1) + torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx) + torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1) def values(self): # dispatch to get proper view relationship