diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8f195f60d15..45433b6795c 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -2799,6 +2799,22 @@ class MiscTests(torch._dynamo.test_case.TestCase): with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"): gm = torch.fx.symbolic_trace(optimized) + @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False) + def test_no_error_on_nested_fx_trace(self): + input = torch.rand(2, 3) + + def f(x): + x + x + + real = f(input) + + optimized = torch._dynamo.optimize("eager")(f) + self.assertTrue(same(optimized(input), real)) + + # should not error + gm = torch.fx.symbolic_trace(optimized) + self.assertTrue(same(gm(input), real)) + def test_inference_mode(self): @torch.inference_mode() def func(x, y): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index cd6aedee604..09bfa572d77 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -2,7 +2,6 @@ import contextlib import copy import functools import inspect -import itertools import logging import os import sys @@ -150,20 +149,17 @@ class _TorchDynamoContext: @functools.wraps(fn) def _fn(*args, **kwargs): - any_arg_is_proxy = any( - map( - lambda arg: isinstance(arg, torch.fx.Proxy), - itertools.chain(args, kwargs.values()), - ) - ) - if any_arg_is_proxy: + if ( + not isinstance(self, DisableContext) + and torch.fx._symbolic_trace.is_fx_tracing() + ): if config.error_on_nested_fx_trace: raise RuntimeError( "Detected that you are using FX to symbolically trace " "a dynamo-optimized function. This is not supported at the moment." ) else: - return fn + return fn(*args, **kwargs) on_enter() prior = set_eval_frame(callback)