mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update base for Update on "Test on in-graph constructed NJTs"
A recent set of bugs has been cropping up related to NJTs that constructed in-graph within a compiled function. This exercises different paths related to symbolic nested ints, etc. Some examples: * #145874 * #146644 To get ahead of these, we should do NJT testing for this case as well. This PR parametrizes the OpInfo tests for compile + forward to cover both in-graph constructed NJT and normal input cases. TBD what fails.. TODO: * Do this for compile + backward tests also (?) [ghstack-poisoned]
This commit is contained in:
parent
b975d44576
commit
b60de30769
1 changed files with 2 additions and 6 deletions
|
|
@ -8651,9 +8651,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
|
|||
def f(*args, **kwargs):
|
||||
return op_fn(*args, **kwargs)
|
||||
|
||||
compiled_f = torch.compile(
|
||||
f, fullgraph=True, backend="inductor"
|
||||
)
|
||||
compiled_f = torch.compile(f, fullgraph=True, backend="inductor")
|
||||
|
||||
out_ref = f(sample.input, *sample.args, **sample.kwargs)
|
||||
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
|
||||
|
|
@ -8705,9 +8703,7 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
|
|||
def f(*args, **kwargs):
|
||||
return op_fn(*args, **kwargs)
|
||||
|
||||
compiled_f = torch.compile(
|
||||
f, fullgraph=True, backend="inductor"
|
||||
)
|
||||
compiled_f = torch.compile(f, fullgraph=True, backend="inductor")
|
||||
|
||||
out_ref = f(sample.input, *sample.args, **sample.kwargs)
|
||||
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue