mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
[NestedTensor] Use maybe_mark_dynamic instead of mark_dynamic (#127453)
Fixes #127097
**TL;DR**: dimensions marked with mark_dynamic can result in assertion failures if the marked-dynamic dimensions get specialized. In NJT, we don't care _that_ much that a dimension is marked as dynamic. So instead, mark with `maybe_mark_dynamic` which suggests that a dimension should be dynamic, but doesn't fail if the dimension gets specialized.
**Background**:
NJT marks the values tensor as dynamic:
49ad90349d/torch/nested/_internal/nested_tensor.py (L122)
It does this for two reasons:
1. **Conceptual**: We know that this dimension _should_ be dynamic; it's a nested tensor, so the sequence lengths will _probably_ vary between batches in the common case. Therefore, we should compile it as dynamic to prevent needing a recompile to trigger automatic dynamic shapes.
2. **Implementation detail**: Right now we run into issues with torch.compile / tensor_unflatten / other details when the dimensions are not marked as dynamic. We have some attempts to remove this (e.g. https://github.com/pytorch/pytorch/pull/126563) but while testing this I wasn't able to get all tests to pass, so there could be potential regressions here if we removed the mark_dynamic.
**Justification for this change**
1. **Conceptual**: AFAIK, we don't care enough about the dynamism of this dimension to error out if we specialize. We'd prefer that we don't have to recompile to get automatic dynamic shapes, but it's also better to not have this issue (and not to force the user to go hunt down all the other equivalent shapes to mark them as dynamic as well). This solution allows us to suggest the dynamism but not force it.
2. **Implementation detail**: This still marks the dimension as symbolic at the beginning of dynamo tracing, so we will (probably) avoid a lot of the issues we run into when we completely remove the `mark_dynamic` decorators.
Differential Revision: [D57933779](https://our.internmc.facebook.com/intern/diff/D57933779)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127453
Approved by: https://github.com/soulitzer, https://github.com/YuqingJ
This commit is contained in:
parent
6bfc6e0875
commit
f33beb767d
2 changed files with 68 additions and 2 deletions
|
|
@ -10,6 +10,8 @@ import math
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.testing
|
||||
import torch.nn
|
||||
import torch.nn.functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
|
|
@ -4008,6 +4010,70 @@ class TestNestedTensorSubclass(TestCase):
|
|||
nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
|
||||
check_size(nt1_t, nt2_t, nt3_t, nt4_t)
|
||||
|
||||
@skipIfTorchDynamo("compiles internally")
|
||||
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
def test_specialize_dynamic_shape(self, device):
|
||||
values = torch.randn((18, 16), device=device)
|
||||
offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
|
||||
like_values = torch.randn_like(values)
|
||||
|
||||
# this marks values as dynamic
|
||||
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
||||
|
||||
def fn(values, same_size):
|
||||
# here, the dynamic shape is specialized by same_size's shape
|
||||
# https://github.com/pytorch/pytorch/issues/127097
|
||||
# make sure this doesn't error out in torch.compile
|
||||
return values + same_size
|
||||
|
||||
self.assertEqual(
|
||||
fn(values, like_values),
|
||||
torch.compile(fn)(values, like_values),
|
||||
)
|
||||
|
||||
@skipIfTorchDynamo("compiles internally")
|
||||
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
||||
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
||||
def test_specialize_dynamic_shape_recompile(self, device):
|
||||
def generate_inp(total_len):
|
||||
values = torch.randn((total_len, 16), device=device)
|
||||
offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
|
||||
like_values = torch.randn_like(values)
|
||||
return values, offsets, like_values
|
||||
|
||||
def check_results(ref_fn, res_fn, args):
|
||||
values, offsets, like_values = args
|
||||
# this may add dynamic shape markings
|
||||
# goal of this test is to make sure that whatever markings are there,
|
||||
# we eventually stop recompiling as shape changes.
|
||||
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
||||
|
||||
self.assertEqual(
|
||||
ref_fn(values, like_values),
|
||||
res_fn(values, like_values),
|
||||
)
|
||||
|
||||
|
||||
def fn(values, same_size):
|
||||
return values + same_size
|
||||
|
||||
compile_counter = torch._dynamo.testing.CompileCounter()
|
||||
|
||||
compiled_fn = torch._dynamo.optimize(compile_counter, nopython=True)(fn)
|
||||
check_results(fn, compiled_fn, generate_inp(18))
|
||||
self.assertEqual(compile_counter.frame_count, 1)
|
||||
|
||||
check_results(fn, compiled_fn, generate_inp(19))
|
||||
# we'll probably recompile here with dynamic shapes - it's okay if not though.
|
||||
frame_count_2 = compile_counter.frame_count
|
||||
self.assertIn(frame_count_2, [1, 2])
|
||||
|
||||
# make sure that by now we've already compiled with dynamic shapes, so additional
|
||||
# shapes should not trigger additional recompiles.
|
||||
check_results(fn, compiled_fn, generate_inp(20))
|
||||
self.assertEqual(compile_counter.frame_count, frame_count_2)
|
||||
|
||||
# Doesn't work until we have real views
|
||||
@xfailIfTorchDynamo
|
||||
# Note 1: Math fallback doesn't work with bfloat16 on CUDA
|
||||
|
|
|
|||
|
|
@ -118,8 +118,8 @@ class NestedTensor(torch.Tensor):
|
|||
self._metadata_cache = kwargs.get("_metadata_cache") or {}
|
||||
|
||||
# collapsed ragged dim must always be dynamic
|
||||
torch._dynamo.mark_dynamic(self, self._ragged_idx)
|
||||
torch._dynamo.mark_dynamic(self._values, self._ragged_idx - 1)
|
||||
torch._dynamo.maybe_mark_dynamic(self, self._ragged_idx)
|
||||
torch._dynamo.maybe_mark_dynamic(self._values, self._ragged_idx - 1)
|
||||
|
||||
def values(self):
|
||||
# dispatch to get proper view relationship
|
||||
|
|
|
|||
Loading…
Reference in a new issue