mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Ignore logger methods to avoid graph breaks (#139403)
Fixes #132635. Calls to logging.Logger methods cause a graph break; this PR allows the user to avoid these graph breaks (for specific methods) by setting DISABLE_LOGS_WHILE_COMPILING to 1. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139403 Approved by: https://github.com/williamwen42
This commit is contained in:
parent
41952c1876
commit
16ea0ddcdb
4 changed files with 76 additions and 1 deletion
|
|
@ -1,5 +1,6 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
import io
|
||||
import logging
|
||||
import warnings
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -9,6 +10,64 @@ import torch._dynamo.test_case
|
|||
import torch._dynamo.testing
|
||||
from torch._dynamo.testing import same
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger_test = logging.getLogger("test")
|
||||
|
||||
|
||||
def f_info(x):
    """Double `x`, emit one info-level log line, then square and return.

    The unconditional `logger.info` call in the middle is the interesting
    part: it is what causes (or, when ignored, avoids) a dynamo graph break.
    """
    doubled = x + x
    logger.info("moo")
    return doubled * doubled
|
||||
|
||||
|
||||
def f_isEnabledFor(x):
    """Double `x`; log "moo" only when INFO is enabled; square and return.

    Exercises `Logger.isEnabledFor` (rather than a direct logging call) as
    the traced logger method.
    """
    result = x + x
    info_enabled = logger.isEnabledFor(logging.INFO)
    if info_enabled:
        logger.info("moo")
    result = result * result
    return result
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
    """Checks that `config.ignore_logger_methods` turns the listed logger
    methods into no-ops during tracing: the log line is suppressed and no
    graph break is recorded. Methods of a *different* logger (``logger_test``)
    or an empty config must leave the original behavior intact."""

    @parametrize(
        "ignore_method, fn, should_ignore_logger",
        [
            (None, f_info, False),
            (logger_test.info, f_info, False),
            (None, f_isEnabledFor, False),
            (logger_test.isEnabledFor, f_isEnabledFor, False),
            (logger.info, f_info, True),
            (logging.Logger.info, f_info, True),
            (logger.isEnabledFor, f_isEnabledFor, True),
            (logging.Logger.isEnabledFor, f_isEnabledFor, True),
        ],
    )
    def test_ignore_logger(self, ignore_method, fn, should_ignore_logger):
        counters.clear()
        inp = torch.randn(3, 3)
        eager_result = fn(inp)
        with torch._dynamo.config.patch(ignore_logger_methods={ignore_method}):
            compiled_fn = torch.compile(backend="eager")(fn)
            with self.assertLogs(logger, level="INFO") as captured:
                # assertLogs fails if nothing is logged at all, so always
                # emit one line regardless of whether `fn`'s log is ignored.
                logger.info("call logger info to avoid error")
                compiled_result = compiled_fn(inp)
        # captured.output entries look like "LEVEL:logger_name:message";
        # keep only the message part.
        messages = [record.split(":", 2)[2] for record in captured.output]

        self.assertTrue(same(eager_result, compiled_result))
        if should_ignore_logger:
            # Ignored: no "moo" was logged and tracing did not break.
            self.assertNotIn("moo", messages)
            self.assertEqual(len(counters["graph_break"]), 0)
        else:
            # Not ignored: the log line ran eagerly, via one graph break.
            self.assertIn("moo", messages)
            self.assertEqual(len(counters["graph_break"]), 1)
|
||||
|
||||
|
||||
class ReorderLogsTests(torch._dynamo.test_case.TestCase):
|
||||
|
|
|
|||
|
|
@ -447,6 +447,12 @@ log_compilation_metrics = True
|
|||
# mutated after the print statement.
reorderable_logging_functions: Set[Callable[[Any], None]] = set()

# A set of logger methods that Dynamo will treat as no-ops while tracing,
# so that calling them inside a compiled region does not cause a graph break.
# Add logging.Logger.<method> (the unbound function) to ignore that method on
# every logger, or logger.<method> (a bound method) to ignore it only for that
# one logger instance.
ignore_logger_methods: Set[Callable[..., Any]] = set()

# simulates what would happen if we didn't have support for BUILD_SET opcode,
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
|
||||
|
|
|
|||
|
|
@ -1089,6 +1089,7 @@ def _compile(
|
|||
"inject_BUILD_SET_unimplemented_TESTING_ONLY",
|
||||
"_autograd_backward_strict_mode_banned_ops",
|
||||
"reorderable_logging_functions",
|
||||
"ignore_logger_methods",
|
||||
"traceable_tensor_subclasses",
|
||||
"_custom_ops_profile",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1454,6 +1454,7 @@ class LoggingLoggerVariable(VariableTracker):
|
|||
|
||||
def __init__(self, value, **kwargs) -> None:
    """Wrap a logger object so Dynamo can intercept its method calls.

    value: the underlying logger instance being traced; remaining
    kwargs are forwarded to the VariableTracker base class.
    """
    super().__init__(**kwargs)
    # The wrapped logger; read back via getattr in call_method.
    self.value = value
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
|
|
@ -1465,7 +1466,15 @@ class LoggingLoggerVariable(VariableTracker):
|
|||
if tx.export:
|
||||
# For export cases, we can just make debugging functions no-ops
|
||||
return
|
||||
unimplemented("Logger not supported for non-export cases")
|
||||
method = getattr(self.value, name, None)
|
||||
function = getattr(method, "__func__", None)
|
||||
if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
|
||||
return variables.ConstantVariable.create(None)
|
||||
unimplemented(
|
||||
"Logger not supported for non-export cases. "
|
||||
"To avoid graph breaks caused by logger in compile-mode, it is recommended to"
|
||||
" disable logging by adding logging methods to config.ignore_logger_methods"
|
||||
)
|
||||
|
||||
|
||||
class ConstantLikeVariable(VariableTracker):
|
||||
|
|
|
|||
Loading…
Reference in a new issue