Upgrade LoggingTensor mode and add traceback collection (#103734)

Parts borrowed from: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103734
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer 2023-06-21 12:12:52 -04:00 committed by PyTorch MergeBot
parent 09fdea8564
commit 3912b722f3
4 changed files with 139 additions and 87 deletions

View file

@ -1312,8 +1312,8 @@ def forward(self, arg0_1):
# Make sure that functionalization ran the "+" kernel
# with a functional + non-functional tensor, and wrapped the output appropriately.
self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""")
def test_mixed_wrappers_invalid(self):
x1_not_functional = torch.ones(4)

View file

@ -1155,8 +1155,8 @@ class TestPrimsBasic(TestCase):
log_input("input", r)
prims.sin(r)
self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")
def test_mul_complex(self):
prims.mul(torch.randn(2), 1 + 1j)

View file

@ -1485,13 +1485,13 @@ class TestPythonDispatch(TestCase):
# TODO: figure out why broken
# self.assertEqual(saved_x._version, x._version)
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul.Tensor($0, $0)
$2 = input('grad_y')
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
True = torch._ops.aten.is_same_size.default($1, $2)
$3 = torch._ops.aten.mul.Tensor($2, $0)
$4 = torch._ops.aten.mul.Tensor($2, $0)
$5 = torch._ops.aten.add.Tensor($4, $3)''')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''')
def test_out(self) -> None:
with capture_logs() as logs:
@ -1505,9 +1505,9 @@ $5 = torch._ops.aten.add.Tensor($4, $3)''')
# TODO: arguably this shouldn't pass and we should complain
# that out isn't a kwarg
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs.out($0, out=$1)''')
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)''')
def test_kwarg_only(self) -> None:
with capture_logs() as logs:
@ -1526,14 +1526,14 @@ $2 = torch._ops.aten.abs.out($0, out=$1)''')
# The expectation is that beta/alpha don't show up when they're
# defaulted. This is even if the user explicitly specified it.
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = input('z')
$3 = torch._ops.aten.addmv.default($0, $1, $2)
$4 = torch._ops.aten.addmv.default($0, $1, $2)
$5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
def test_kwarg_only_and_positional_default(self) -> None:
with capture_logs() as logs:
@ -1547,11 +1547,11 @@ $7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
# What we are testing here is that we omit arg2
# if it is defaulted, even if a kwarg is set
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten._foobar.default($0)
$2 = torch._ops.aten._foobar.default($0, False)
$3 = torch._ops.aten._foobar.default($0, arg3=False)
$4 = torch._ops.aten._foobar.default($0, False, arg3=False)''')
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)''')
def test_produce_real_type(self) -> None:
with capture_logs() as logs:
@ -1564,12 +1564,12 @@ $4 = torch._ops.aten._foobar.default($0, False, arg3=False)''')
# triggerable using tensor subclasses (need to use a mode)
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2 = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3 = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4 = torch._ops.aten.select.int($3, 1, 1)
$5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''')
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''')
def test_optional_tensor_list(self) -> None:
def weird(xs):
@ -1585,9 +1585,8 @@ $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)'''
torch.ops.my_lib.weird.default([None, x])
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.],
[1., 1.]]))])''')
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
@ -1639,9 +1638,9 @@ $1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.],
# this test here to make sure we don't regress even further (it
# would be bad if calling .detach() once emits 3+ detaches).
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.detach.default($0)
$2 = torch._ops.aten.detach.default($1)''')
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)''')
def test_storage(self) -> None:
# For now, just make sure it doesn't crash. Ideally, we should
@ -1744,14 +1743,14 @@ $2 = torch._ops.aten.detach.default($1)''')
# self.assertEqual(escape[0]._version, x._version)
self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3 = input('grad_output')
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
True = torch._ops.aten.is_same_size.default($2, $3)
$4 = torch._ops.aten.mul.Tensor($3, 2)
$5 = torch._ops.aten.mul.Tensor($4, $0)
$6 = torch._ops.aten.add_.Tensor($1, $5)''')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
def test_subclass_creation(self):
# Make sure these statements runs without error
@ -1934,7 +1933,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
with LoggingTensorMode():
torch.empty([])
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
x = torch.randn([])
@ -1942,8 +1941,7 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
with capture_logs(is_mode=True) as logs:
with LoggingTensorMode():
x + y
self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)""")
self.assertExpectedInline('\n'.join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""")
def test_nested_push_logging_tensor_mode(self):
x = torch.randn([])
@ -1955,10 +1953,10 @@ $2 = torch._ops.aten.add.Tensor($0, $1)""")
x + y
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
def test_capture_logs_with_torch_dispatch_mode(self):
x = torch.randn([])
@ -1967,8 +1965,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
torch.empty([])
x + y
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
x = torch.randn([])
y = torch.randn([])
@ -1979,10 +1977,10 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
x + y
self.assertExpectedInline('\n'.join(logs2), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3 = torch._ops.aten.add.Tensor($1, $2)
$3 = torch._ops.aten.add.Tensor($1, $2)""")
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
self.assertEqual(logs1, logs2)
@ -2238,8 +2236,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
with reenabled:
torch.empty([])
self.assertExpectedInline('\n'.join(logs), """\
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
def test_error_using_class_method_on_mode(self):

View file

@ -1,12 +1,31 @@
import torch
from torch.utils._pytree import tree_map
from typing import Iterator, List
from typing import Iterator, List, Optional
import logging
import contextlib
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.weak import WeakTensorKeyDictionary
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks
_dtype_abbrs = {
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.float32: "f32",
torch.float16: "f16",
torch.complex32: "c32",
torch.complex64: "c64",
torch.complex128: "c128",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.bool: "b8",
torch.uint8: "u8",
}
# How the chain of calls works for LoggingTensor:
# 1. Call torch.sin
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
@ -76,53 +95,88 @@ class LoggingTensorReentrant(LoggingTensor):
# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
log_list: List[str]
next_shortid: int
def __init__(self, log_list: List[str], use_shortid_for_all_tensors: bool) -> None:
def __init__(
self, log_list: List[str], use_shortid_for_all_tensors: bool,
with_type: bool, tracebacks_list: Optional[List]) -> None:
logging.Handler.__init__(self)
self.log_list = log_list
self.next_shortid = 0
self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
self.tracebacks_list = tracebacks_list
self.memo = WeakTensorKeyDictionary()
self.next_id = 0
self.with_type = with_type
# WARNING: not deterministic over multiple threads, this matters for
# autograd
def _shortid(self, o: object) -> int:
if not hasattr(o, '_shortid'):
o._shortid = self.next_shortid # type: ignore[attr-defined]
self.next_shortid += 1
return o._shortid # type: ignore[attr-defined]
def _shortid(self, t: torch.Tensor) -> int:
if t not in self.memo:
self.memo[t] = self.next_id
self.next_id += 1
return self.memo[t]
def _fmt(self, a: object) -> str:
def _fmt(self, a: object, with_type: bool = False) -> str:
cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
return f'${self._shortid(a)}' if isinstance(a, cond_cls) else repr(a)
if isinstance(a, cond_cls):
maybe_type = ""
if with_type and self.with_type:
maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
x = f"${self._shortid(a)}{maybe_type}"
return x
else:
return repr(a)
def emit(self, record):
fmt_args = ", ".join(itertools.chain(
(self._fmt(a) for a in record.args[0]),
(f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
))
fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
fmt_args = ", ".join(
itertools.chain(
(str(tree_map(self._fmt, a)) for a in record.args[0]),
(f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
)
)
fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
if self.tracebacks_list is not None:
self.tracebacks_list.append(record.traceback)
def log_input(name: str, var: object):
logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,))
logging.getLogger("LoggingTensor").info("input", (name,), {}, var)
class GatherTraceback(logging.Filter):
def __init__(self, python=True, script=True, cpp=False):
self.python = python
self.script = script
self.cpp = cpp
def filter(self, record):
record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
return True
@contextlib.contextmanager
def capture_logs(is_mode=False) -> Iterator[List[str]]:
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
collect_traceback = python_tb or script_tb or cpp_tb
logger = logging.getLogger("LoggingTensor")
log_list: List[str] = []
handler = LoggingTensorHandler(log_list, use_shortid_for_all_tensors=is_mode)
tracebacks_list: List[str] = []
handler = LoggingTensorHandler(
log_list,
with_type=True,
use_shortid_for_all_tensors=is_mode,
tracebacks_list=tracebacks_list if collect_traceback else None
)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.propagate = False
if collect_traceback:
logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb))
try:
yield log_list
if collect_traceback:
yield log_list, tracebacks_list
else:
yield log_list
finally:
symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
tracebacks_list.clear()
tracebacks_list.extend(symbolized_tracebacks)
logger.removeHandler(handler)
@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode():
with LoggingTensorMode(), capture_logs(True) as logs:
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
yield logs