From b60de30769bef096672eb5f924333baa64b5ec96 Mon Sep 17 00:00:00 2001 From: Joel Schlosser Date: Fri, 7 Feb 2025 15:38:04 -0500 Subject: [PATCH] Update base for Update on "Test on in-graph constructed NJTs" A recent set of bugs has been cropping up related to NJTs that constructed in-graph within a compiled function. This exercises different paths related to symbolic nested ints, etc. Some examples: * #145874 * #146644 To get ahead of these, we should do NJT testing for this case as well. This PR parametrizes the OpInfo tests for compile + forward to cover both in-graph constructed NJT and normal input cases. TBD what fails.. TODO: * Do this for compile + backward tests also (?) [ghstack-poisoned] --- test/test_nestedtensor.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 04bb127178e..3b6f5e9f2c2 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -8651,9 +8651,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase): def f(*args, **kwargs): return op_fn(*args, **kwargs) - compiled_f = torch.compile( - f, fullgraph=True, backend="inductor" - ) + compiled_f = torch.compile(f, fullgraph=True, backend="inductor") out_ref = f(sample.input, *sample.args, **sample.kwargs) out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs) @@ -8705,9 +8703,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase): def f(*args, **kwargs): return op_fn(*args, **kwargs) - compiled_f = torch.compile( - f, fullgraph=True, backend="inductor" - ) + compiled_f = torch.compile(f, fullgraph=True, backend="inductor") out_ref = f(sample.input, *sample.args, **sample.kwargs) out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)