diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index 80ffe2e13cd..b67013079fa 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import io +import logging import warnings from unittest.mock import patch @@ -9,6 +10,64 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +logger = logging.getLogger(__name__) +logger_test = logging.getLogger("test") + + +def f_info(x): + x = x + x + logger.info("moo") + x = x * x + return x + + +def f_isEnabledFor(x): + x = x + x + if logger.isEnabledFor(logging.INFO): + logger.info("moo") + x = x * x + return x + + +@instantiate_parametrized_tests +class IgnoreLogsTests(torch._dynamo.test_case.TestCase): + @parametrize( + "ignore_method, fn, should_ignore_logger", + [ + (None, f_info, False), + (logger_test.info, f_info, False), + (None, f_isEnabledFor, False), + (logger_test.isEnabledFor, f_isEnabledFor, False), + (logger.info, f_info, True), + (logging.Logger.info, f_info, True), + (logger.isEnabledFor, f_isEnabledFor, True), + (logging.Logger.isEnabledFor, f_isEnabledFor, True), + ], + ) + def test_ignore_logger(self, ignore_method, fn, should_ignore_logger): + counters.clear() + x = torch.randn(3, 3) + orig_out = fn(x) + with torch._dynamo.config.patch(ignore_logger_methods={ignore_method}): + opt_f = torch.compile(backend="eager")(fn) + with self.assertLogs(logger, level="INFO") as captured: + logger.info("call logger info to avoid error") + opt_out = opt_f(x) + printed_output = [entry.split(":", 2)[2] for entry in captured.output] + + self.assertTrue(same(orig_out, opt_out)) + if should_ignore_logger: + self.assertNotIn("moo", printed_output) + self.assertEqual(len(counters["graph_break"]), 0) + else: + self.assertIn("moo", printed_output) + self.assertEqual(len(counters["graph_break"]), 1) class ReorderLogsTests(torch._dynamo.test_case.TestCase): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 4d66727add9..058232b4ea5 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -447,6 +447,12 @@ log_compilation_metrics = True # mutated after the print statement. reorderable_logging_functions: Set[Callable[[Any], None]] = set() +# A set of methods that will be ignored while tracing, +# to prevent graph breaks. +# Add logging.Logger. to ignore all calls for method, +# or logger. to ignore calls for method from this logger instance only. +ignore_logger_methods: Set[Callable[..., Any]] = set() + # simulates what would happen if we didn't have support for BUILD_SET opcode, # used for testing inject_BUILD_SET_unimplemented_TESTING_ONLY = False diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 22291ef3d89..3ab2aae3576 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1089,6 +1089,7 @@ def _compile( "inject_BUILD_SET_unimplemented_TESTING_ONLY", "_autograd_backward_strict_mode_banned_ops", "reorderable_logging_functions", + "ignore_logger_methods", "traceable_tensor_subclasses", "_custom_ops_profile", } diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index ced56ef2ac6..68d81757714 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1454,6 +1454,7 @@ class LoggingLoggerVariable(VariableTracker): def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) + self.value = value def call_method( self, @@ -1465,7 +1466,15 @@ class LoggingLoggerVariable(VariableTracker): if tx.export: # For export cases, we can just make debugging functions no-ops return - unimplemented("Logger not supported for non-export cases") + method = getattr(self.value, name, None) + function = getattr(method, "__func__", None) + if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods): + return variables.ConstantVariable.create(None) + unimplemented( + "Logger not supported for non-export cases. " + "To avoid graph breaks caused by logger in compile-mode, it is recommended to" + " disable logging by adding logging methods to config.ignore_logger_methods" + ) class ConstantLikeVariable(VariableTracker):