From c0e6b4329fe2dd35bb0bf162f4203ad7e0162554 Mon Sep 17 00:00:00 2001
From: Michael Suo
Date: Mon, 7 Nov 2022 22:23:01 -0800
Subject: [PATCH] [dynamo] only error out on nested fx trace if dynamo is optimizing (#88640)

I think this is the final resolution to the issue caused by
https://github.com/pytorch/pytorch/pull/87797. The nvfuser issue that PR
tripped up occurred because, even though we're correctly disabling
torchdynamo via a `DisableContext`, the nested fx trace check was still
firing. This PR properly narrows it to only fire if we're not disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88640
Approved by: https://github.com/yf225
---
 test/dynamo/test_misc.py    | 16 ++++++++++++++++
 torch/_dynamo/eval_frame.py | 14 +++++---------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 8f195f60d15..45433b6795c 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2799,6 +2799,22 @@ class MiscTests(torch._dynamo.test_case.TestCase):
         with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
             gm = torch.fx.symbolic_trace(optimized)
 
+    @patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
+    def test_no_error_on_nested_fx_trace(self):
+        input = torch.rand(2, 3)
+
+        def f(x):
+            x + x
+
+        real = f(input)
+
+        optimized = torch._dynamo.optimize("eager")(f)
+        self.assertTrue(same(optimized(input), real))
+
+        # should not error
+        gm = torch.fx.symbolic_trace(optimized)
+        self.assertTrue(same(gm(input), real))
+
     def test_inference_mode(self):
         @torch.inference_mode()
         def func(x, y):
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index cd6aedee604..09bfa572d77 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -2,7 +2,6 @@ import contextlib
 import copy
 import functools
 import inspect
-import itertools
 import logging
 import os
 import sys
@@ -150,20 +149,17 @@ class _TorchDynamoContext:
 
         @functools.wraps(fn)
         def _fn(*args, **kwargs):
-            any_arg_is_proxy = any(
-                map(
-                    lambda arg: isinstance(arg, torch.fx.Proxy),
-                    itertools.chain(args, kwargs.values()),
-                )
-            )
-            if any_arg_is_proxy:
+            if (
+                not isinstance(self, DisableContext)
+                and torch.fx._symbolic_trace.is_fx_tracing()
+            ):
                 if config.error_on_nested_fx_trace:
                     raise RuntimeError(
                         "Detected that you are using FX to symbolically trace "
                         "a dynamo-optimized function. This is not supported at the moment."
                     )
                 else:
-                    return fn
+                    return fn(*args, **kwargs)
 
             on_enter()
             prior = set_eval_frame(callback)
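Note for reviewers: a minimal sketch of the behavior the new test exercises,
assuming a build that includes this patch. It flips the config flag directly
instead of via @patch.object, and uses an `f` that returns its result so the
traced module's output can be compared; it is an illustration, not part of
the change.

    import torch
    import torch._dynamo

    def f(x):
        return x + x

    inp = torch.rand(2, 3)

    # Dynamo-optimized wrapper around f ("eager" backend just runs the captured graph).
    opt_f = torch._dynamo.optimize("eager")(f)

    # With the flag disabled, nested FX tracing of the optimized callable no
    # longer raises: the wrapper falls through to the original function
    # (the `return fn(*args, **kwargs)` fix in this patch).
    torch._dynamo.config.error_on_nested_fx_trace = False
    gm = torch.fx.symbolic_trace(opt_f)
    assert torch.allclose(gm(inp), f(inp))

With the flag left at its default of True, the same symbolic_trace call still
raises the "Detected that you are using FX ..." RuntimeError, which is what
the pre-existing test_error_on_nested_fx_trace continues to check.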