[dynamo] only error out on nested fx trace if dynamo is optimizing (#88640)

I think this is the final resolution to issue caused by
https://github.com/pytorch/pytorch/pull/87797. The nvfuser issue that PR
tripped up was because, even though we're correctly disabling
torchdynamo via a `DisableContext`, the nested fx trace check was still
firing. This PR properly narrows it to only fire if we're not disabled.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88640
Approved by: https://github.com/yf225
This commit is contained in:
Michael Suo 2022-11-07 22:23:01 -08:00 committed by PyTorch MergeBot
parent a02ea655b5
commit c0e6b4329f
2 changed files with 21 additions and 9 deletions

View file

@ -2799,6 +2799,22 @@ class MiscTests(torch._dynamo.test_case.TestCase):
with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
gm = torch.fx.symbolic_trace(optimized)
@patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
def test_no_error_on_nested_fx_trace(self):
input = torch.rand(2, 3)
def f(x):
x + x
real = f(input)
optimized = torch._dynamo.optimize("eager")(f)
self.assertTrue(same(optimized(input), real))
# should not error
gm = torch.fx.symbolic_trace(optimized)
self.assertTrue(same(gm(input), real))
def test_inference_mode(self):
@torch.inference_mode()
def func(x, y):

View file

@ -2,7 +2,6 @@ import contextlib
import copy
import functools
import inspect
import itertools
import logging
import os
import sys
@ -150,20 +149,17 @@ class _TorchDynamoContext:
@functools.wraps(fn)
def _fn(*args, **kwargs):
any_arg_is_proxy = any(
map(
lambda arg: isinstance(arg, torch.fx.Proxy),
itertools.chain(args, kwargs.values()),
)
)
if any_arg_is_proxy:
if (
not isinstance(self, DisableContext)
and torch.fx._symbolic_trace.is_fx_tracing()
):
if config.error_on_nested_fx_trace:
raise RuntimeError(
"Detected that you are using FX to symbolically trace "
"a dynamo-optimized function. This is not supported at the moment."
)
else:
return fn
return fn(*args, **kwargs)
on_enter()
prior = set_eval_frame(callback)