mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Summary:
Description:
- Before this change, the logging level could only be changed by setting the
environment variable "PYTORCH_JIT_LOG_LEVEL"
- The level can now also be changed from Python
- Stream configuration has not been added yet
- Configuration is stored in a singleton class managing the options
Issue Link: https://github.com/pytorch/pytorch/issues/54188
Gotchas:
- Created separate functions
`::torch::jit::get_jit_logging_levels`/`set_jit_logging_levels` instead of
calling the singleton class's methods directly
- This is because, when running the test cases, two different instances
of the singleton are created: one for the test suite and one for the actual
code (`jit_log.cpp`)
- If these methods were called directly, `is_enabled` would query the
singleton in `jit_log.cpp` while the config was being set on the other
singleton
- See: https://stackoverflow.com/questions/55467246/my-singleton-can-be-called-multiple-times
API:
- To set the level: `torch._C._jit_set_logging_option("level")`
- To get the level: `torch._C._jit_get_logging_option()`
Testing:
- Unit tests were added for C++
- A very simple unit test was added for Python that only checks that the API
is called correctly
- The API was checked by running a trace in a sample Python file
- Set the env variable to "" and used `_jit_set_logging_option` in Python to
set the option to `>dead_code_elimination`
- The error output contained logs of the form [DUMP ...], [UPDATE ...], etc.
Fixes https://github.com/pytorch/pytorch/issues/54188
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58821
Reviewed By: soulitzer
Differential Revision: D29116712
Pulled By: ZolotukhinM
fbshipit-source-id: 8f2861ee2bd567fb63b405953d035ca657a3200f
117 lines · 4.2 KiB · Python
import os
|
|
import sys
|
|
|
|
import torch
|
|
|
|
# Make the helper files in test/ importable. They live one directory above
# the directory that contains this file; symlinks are resolved first so the
# lookup works even when this file is reached through a link.
_this_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_this_dir)
del _this_dir  # keep the module namespace identical to before
sys.path.append(pytorch_test_dir)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
if __name__ == '__main__':
    # This module only defines test cases; it is meant to be run through the
    # test/test_jit.py driver, so direct execution is refused.
    _usage = ("This test file is not meant to be run directly, use:\n\n"
              "\tpython test/test_jit.py TESTNAME\n\n"
              "instead.")
    raise RuntimeError(_usage)
|
|
|
|
class TestLogging(JitTestCase):
    """Tests for the TorchScript stats logger (``torch.jit._logging``) and the
    Python bindings for the JIT logging option.

    Each test that installs a ``LockingLogger`` restores the previously
    installed logger in a ``finally`` clause. The logging-option test likewise
    restores the previous option, so no global state leaks between tests.
    """

    def test_bump_numeric_counter(self):
        class ModuleThatLogs(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                for i in range(x.size(0)):
                    x += 1.0
                    torch.jit._logging.add_stat_value('foo', 1)

                if bool(x.sum() > 0.0):
                    torch.jit._logging.add_stat_value('positive', 1)
                else:
                    torch.jit._logging.add_stat_value('negative', 1)
                return x

        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtl = ModuleThatLogs()
            for _ in range(5):
                mtl(torch.rand(3, 4, 5))

            # 5 calls x 3 loop iterations (x.size(0) == 3) bump 'foo' 15 times.
            # torch.rand values are non-negative and 1.0 is added on every
            # iteration, so x.sum() > 0 and each call counts as 'positive'.
            self.assertEqual(logger.get_counter_val('foo'), 15)
            self.assertEqual(logger.get_counter_val('positive'), 5)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_trace_numeric_counter(self):
        def foo(x):
            torch.jit._logging.add_stat_value('foo', 1)
            return x + 1.0

        # Tracing executes foo once, but before our logger is installed, so
        # only the single call to the traced function below is counted.
        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

            self.assertEqual(logger.get_counter_val('foo'), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter(self):
        # forward is not decorated with @torch.jit.script_method here, so it
        # runs as a plain Python method. test_time_measurement_counter_script
        # below exercises the compiled version of the same body.
        class ModuleThatTimes(torch.jit.ScriptModule):
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val('mytimer'), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_time_measurement_counter_script(self):
        class ModuleThatTimes(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                tp_start = torch.jit._logging.time_point()
                for i in range(30):
                    x += 1.0
                tp_end = torch.jit._logging.time_point()
                torch.jit._logging.add_stat_value('mytimer', tp_end - tp_start)
                return x

        mtm = ModuleThatTimes()
        logger = torch.jit._logging.LockingLogger()
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            mtm(torch.rand(3, 4))
            self.assertGreater(logger.get_counter_val('mytimer'), 0)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_counter_aggregation(self):
        def foo(x):
            for i in range(3):
                torch.jit._logging.add_stat_value('foo', 1)
            return x + 1.0

        traced = torch.jit.trace(foo, torch.rand(3, 4))
        logger = torch.jit._logging.LockingLogger()
        # With AVG aggregation the three samples of value 1 recorded per call
        # average to 1 rather than accumulating.
        logger.set_aggregation_type('foo', torch.jit._logging.AggregationType.AVG)
        old_logger = torch.jit._logging.set_logger(logger)
        try:
            traced(torch.rand(3, 4))

            self.assertEqual(logger.get_counter_val('foo'), 1)
        finally:
            torch.jit._logging.set_logger(old_logger)

    def test_logging_levels_set(self):
        # The JIT logging option lives in a process-wide singleton, so restore
        # the previous value afterwards; otherwise the 'foo' setting would
        # leak into every test that runs later in the same process.
        old_option = torch._C._jit_get_logging_option()
        try:
            torch._C._jit_set_logging_option('foo')
            self.assertEqual('foo', torch._C._jit_get_logging_option())
        finally:
            torch._C._jit_set_logging_option(old_option)