Update base for Update on "Test on in-graph constructed NJTs"

A recent set of bugs has been cropping up related to NJTs that constructed in-graph within a compiled function. This exercises different paths related to symbolic nested ints, etc. Some examples:
* #145874
* #146644

To get ahead of these, we should do NJT testing for this case as well.

This PR parametrizes the OpInfo tests for compile + forward to cover both in-graph constructed NJT and normal input cases. TBD what fails..

TODO:
* Do this for compile + backward tests also (?)

[ghstack-poisoned]
This commit is contained in:
Joel Schlosser 2025-02-07 15:38:04 -05:00
parent b975d44576
commit b60de30769

View file

@ -8651,9 +8651,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
def f(*args, **kwargs):
return op_fn(*args, **kwargs)
compiled_f = torch.compile(
f, fullgraph=True, backend="inductor"
)
compiled_f = torch.compile(f, fullgraph=True, backend="inductor")
out_ref = f(sample.input, *sample.args, **sample.kwargs)
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
@ -8705,9 +8703,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
def f(*args, **kwargs):
return op_fn(*args, **kwargs)
compiled_f = torch.compile(
f, fullgraph=True, backend="inductor"
)
compiled_f = torch.compile(f, fullgraph=True, backend="inductor")
out_ref = f(sample.input, *sample.args, **sample.kwargs)
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)