# pytorch/test/jit/test_logging.py

# Owner(s): ["oncall: jit"]
# ruff: noqa: F841
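
# Tests for the torch.jit._logging stats-collection API: numeric counters,
# traced counters, time-point measurements, counter aggregation, and the
# JIT logging-level option.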

import os
import sys

import torch

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestLogging(JitTestCase):
    def test_bump_numeric_counter(self):
        class ModuleThatLogs(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(0)):
                    x += 1.0
                    torch.jit._logging.add_stat_value("foo", 1)

                if bool(x.sum() > 0.0):
                    torch.jit._logging.add_stat_value("positive", 1)
                else:
                    torch.jit._logging.add_stat_value("negative", 1)

                return x

        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtl = ModuleThatLogs()
            for i in range(5):
                mtl(torch.rand(3, 4, 5))
            self.assertEqual(logger.get_counter_val("foo"), 15)
            self.assertEqual(logger.get_counter_val("positive"), 5)
        finally:
            torch.jit._logging.set_logger(old_logger)
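
    # Tracing (torch.jit.trace) should capture add_stat_value the same way
    # scripting does: the call is recorded in the traced graph and bumps the
    # counter once per invocation of the traced function.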
    def test_trace_numeric_counter(self):
        def foo(x):
            torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))
            self.assertEqual(logger.get_counter_val("foo"), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)
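
    # time_point() yields a timestamp; subtracting two of them gives an elapsed
    # duration that can be logged like any other stat. In this variant forward
    # is not a @torch.jit.script_method, so the timing runs in eager mode.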
    def test_time_measurement_counter(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)
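
    # Same timing pattern as above, but forward is compiled with TorchScript
    # via @torch.jit.script_method, exercising time_point() inside a scripted graph.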
    def test_time_measurement_counter_script(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val("mytimer"), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)
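
    # Counters are summed by default; selecting AggregationType.AVG makes the
    # logger average the logged values instead, so three add_stat_value("foo", 1)
    # calls report 1 rather than 3.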
    def test_counter_aggregation(self):
        def foo(x):
            for i in range(3):
                torch.jit._logging.add_stat_value("foo", 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        logger.set_aggregation_type("foo", torch.jit._logging.AggregationType.AVG)
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))
            self.assertEqual(logger.get_counter_val("foo"), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)
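
    # The JIT logging level used to be configurable only through the
    # PYTORCH_JIT_LOG_LEVEL environment variable. PR #58821
    # (https://github.com/pytorch/pytorch/pull/58821) added a Python API:
    # torch._C._jit_set_logging_option(level) and torch._C._jit_get_logging_option().
    # A realistic value looks like ">dead_code_elimination"; this test only
    # checks that a string set through the API round-trips.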
    def test_logging_levels_set(self):
        torch._C._jit_set_logging_option("foo")
        self.assertEqual("foo", torch._C._jit_get_logging_option())