Ignore logger methods to avoid graph breaks (#139403)

Fixes #132635

Calls to logging.logger cause a graph break, this PR allows the user to avoid these graph breaks (for specific methods) by setting DISABLE_LOGS_WHILE_COMPILING to 1.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139403
Approved by: https://github.com/williamwen42
This commit is contained in:
snahir 2024-12-05 20:12:23 +00:00 committed by PyTorch MergeBot
parent 41952c1876
commit 16ea0ddcdb
4 changed files with 76 additions and 1 deletions

View file

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import io
import logging
import warnings
from unittest.mock import patch
@ -9,6 +10,64 @@ import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
logger = logging.getLogger(__name__)
logger_test = logging.getLogger("test")
def f_info(x):
x = x + x
logger.info("moo")
x = x * x
return x
def f_isEnabledFor(x):
x = x + x
if logger.isEnabledFor(logging.INFO):
logger.info("moo")
x = x * x
return x
@instantiate_parametrized_tests
class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
@parametrize(
"ignore_method, fn, should_ignore_logger",
[
(None, f_info, False),
(logger_test.info, f_info, False),
(None, f_isEnabledFor, False),
(logger_test.isEnabledFor, f_isEnabledFor, False),
(logger.info, f_info, True),
(logging.Logger.info, f_info, True),
(logger.isEnabledFor, f_isEnabledFor, True),
(logging.Logger.isEnabledFor, f_isEnabledFor, True),
],
)
def test_ignore_logger(self, ignore_method, fn, should_ignore_logger):
counters.clear()
x = torch.randn(3, 3)
orig_out = fn(x)
with torch._dynamo.config.patch(ignore_logger_methods={ignore_method}):
opt_f = torch.compile(backend="eager")(fn)
with self.assertLogs(logger, level="INFO") as captured:
logger.info("call logger info to avoid error")
opt_out = opt_f(x)
printed_output = [entry.split(":", 2)[2] for entry in captured.output]
self.assertTrue(same(orig_out, opt_out))
if should_ignore_logger:
self.assertNotIn("moo", printed_output)
self.assertEqual(len(counters["graph_break"]), 0)
else:
self.assertIn("moo", printed_output)
self.assertEqual(len(counters["graph_break"]), 1)
class ReorderLogsTests(torch._dynamo.test_case.TestCase):

View file

@ -447,6 +447,12 @@ log_compilation_metrics = True
# mutated after the print statement.
reorderable_logging_functions: Set[Callable[[Any], None]] = set()
# A set of methods that will be ignored while tracing,
# to prevent graph breaks.
# Add logging.Logger.<method> to ignore all calls for method,
# or logger.<method> to ignore calls for method from this logger instance only.
ignore_logger_methods: Set[Callable[..., Any]] = set()
# simulates what would happen if we didn't have support for BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False

View file

@ -1089,6 +1089,7 @@ def _compile(
"inject_BUILD_SET_unimplemented_TESTING_ONLY",
"_autograd_backward_strict_mode_banned_ops",
"reorderable_logging_functions",
"ignore_logger_methods",
"traceable_tensor_subclasses",
"_custom_ops_profile",
}

View file

@ -1454,6 +1454,7 @@ class LoggingLoggerVariable(VariableTracker):
def __init__(self, value, **kwargs) -> None:
super().__init__(**kwargs)
self.value = value
def call_method(
self,
@ -1465,7 +1466,15 @@ class LoggingLoggerVariable(VariableTracker):
if tx.export:
# For export cases, we can just make debugging functions no-ops
return
unimplemented("Logger not supported for non-export cases")
method = getattr(self.value, name, None)
function = getattr(method, "__func__", None)
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
return variables.ConstantVariable.create(None)
unimplemented(
"Logger not supported for non-export cases. "
"To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by adding logging methods to config.ignore_logger_methods"
)
class ConstantLikeVariable(VariableTracker):