mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] only error out on nested fx trace if dynamo is optimizing (#88640)
I think this is the final resolution to the issue caused by https://github.com/pytorch/pytorch/pull/87797. That PR tripped up nvfuser because, even though we correctly disable torchdynamo via a `DisableContext`, the nested FX-trace check was still firing. This PR properly narrows the check so it only fires when dynamo is not disabled. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88640 Approved by: https://github.com/yf225
This commit is contained in:
parent
a02ea655b5
commit
c0e6b4329f
2 changed files with 21 additions and 9 deletions
|
|
@@ -2799,6 +2799,22 @@ class MiscTests(torch._dynamo.test_case.TestCase):
|
|||
with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
|
||||
gm = torch.fx.symbolic_trace(optimized)
|
||||
|
||||
@patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
|
||||
def test_no_error_on_nested_fx_trace(self):
|
||||
input = torch.rand(2, 3)
|
||||
|
||||
def f(x):
|
||||
x + x
|
||||
|
||||
real = f(input)
|
||||
|
||||
optimized = torch._dynamo.optimize("eager")(f)
|
||||
self.assertTrue(same(optimized(input), real))
|
||||
|
||||
# should not error
|
||||
gm = torch.fx.symbolic_trace(optimized)
|
||||
self.assertTrue(same(gm(input), real))
|
||||
|
||||
def test_inference_mode(self):
|
||||
@torch.inference_mode()
|
||||
def func(x, y):
|
||||
|
|
|
|||
|
|
@@ -2,7 +2,6 @@ import contextlib
|
|||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
|
@@ -150,20 +149,17 @@ class _TorchDynamoContext:
|
|||
|
||||
@functools.wraps(fn)
|
||||
def _fn(*args, **kwargs):
|
||||
any_arg_is_proxy = any(
|
||||
map(
|
||||
lambda arg: isinstance(arg, torch.fx.Proxy),
|
||||
itertools.chain(args, kwargs.values()),
|
||||
)
|
||||
)
|
||||
if any_arg_is_proxy:
|
||||
if (
|
||||
not isinstance(self, DisableContext)
|
||||
and torch.fx._symbolic_trace.is_fx_tracing()
|
||||
):
|
||||
if config.error_on_nested_fx_trace:
|
||||
raise RuntimeError(
|
||||
"Detected that you are using FX to symbolically trace "
|
||||
"a dynamo-optimized function. This is not supported at the moment."
|
||||
)
|
||||
else:
|
||||
return fn
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
on_enter()
|
||||
prior = set_eval_frame(callback)
|
||||
|
|
|
|||
Loading…
Reference in a new issue