mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Upgrade LoggingTensor mode and add traceback collection (#103734)
Parts borrowed from: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103734
Approved by: https://github.com/albanD
This commit is contained in:
parent
09fdea8564
commit
3912b722f3
4 changed files with 139 additions and 87 deletions
|
|
@ -1312,8 +1312,8 @@ def forward(self, arg0_1):
|
|||
# Make sure that functionalization ran the "+" kernel
|
||||
# with a functional + non-functional tensor, and wrapped the output appropriately.
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$2 = torch._ops.aten.add.Tensor($0, $1)
|
||||
$3 = torch._ops.aten.add.Tensor($2, 1)""")
|
||||
$2: f32[4] = torch._ops.aten.add.Tensor($0, $1)
|
||||
$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""")
|
||||
|
||||
def test_mixed_wrappers_invalid(self):
|
||||
x1_not_functional = torch.ones(4)
|
||||
|
|
|
|||
|
|
@ -1155,8 +1155,8 @@ class TestPrimsBasic(TestCase):
|
|||
log_input("input", r)
|
||||
prims.sin(r)
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = input('input')
|
||||
$1 = torch._ops.prims.sin.default($0)""")
|
||||
$0: f32[2] = input('input')
|
||||
$1: f32[2] = torch._ops.prims.sin.default($0)""")
|
||||
|
||||
def test_mul_complex(self):
|
||||
prims.mul(torch.randn(2), 1 + 1j)
|
||||
|
|
|
|||
|
|
@ -1485,13 +1485,13 @@ class TestPythonDispatch(TestCase):
|
|||
# TODO: figure out why broken
|
||||
# self.assertEqual(saved_x._version, x._version)
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten.mul.Tensor($0, $0)
|
||||
$2 = input('grad_y')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
|
||||
$2: f32[1] = input('grad_y')
|
||||
True = torch._ops.aten.is_same_size.default($1, $2)
|
||||
$3 = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$4 = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$5 = torch._ops.aten.add.Tensor($4, $3)''')
|
||||
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
|
||||
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''')
|
||||
|
||||
def test_out(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
|
|
@ -1505,9 +1505,9 @@ $5 = torch._ops.aten.add.Tensor($4, $3)''')
|
|||
# TODO: arguably this shouldn't pass and we should complain
|
||||
# that out isn't a kwarg
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = torch._ops.aten.abs.out($0, out=$1)''')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = input('y')
|
||||
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)''')
|
||||
|
||||
def test_kwarg_only(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
|
|
@ -1526,14 +1526,14 @@ $2 = torch._ops.aten.abs.out($0, out=$1)''')
|
|||
# The expectation is that beta/alpha don't show up when they're
|
||||
# defaulted. This is even if the user explicitly specified it.
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('y')
|
||||
$2 = input('z')
|
||||
$3 = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$4 = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
|
||||
$6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
|
||||
$7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1, 1] = input('y')
|
||||
$2: f32[1] = input('z')
|
||||
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
|
||||
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
|
||||
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
|
||||
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
|
||||
|
||||
def test_kwarg_only_and_positional_default(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
|
|
@ -1547,11 +1547,11 @@ $7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''')
|
|||
# What we are testing here is that we omit arg2
|
||||
# if it is defaulted, even if a kwarg is set
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten._foobar.default($0)
|
||||
$2 = torch._ops.aten._foobar.default($0, False)
|
||||
$3 = torch._ops.aten._foobar.default($0, arg3=False)
|
||||
$4 = torch._ops.aten._foobar.default($0, False, arg3=False)''')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = torch._ops.aten._foobar.default($0)
|
||||
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
|
||||
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
|
||||
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)''')
|
||||
|
||||
def test_produce_real_type(self) -> None:
|
||||
with capture_logs() as logs:
|
||||
|
|
@ -1564,12 +1564,12 @@ $4 = torch._ops.aten._foobar.default($0, False, arg3=False)''')
|
|||
# triggerable using tensor subclasses (need to use a mode)
|
||||
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
|
||||
$2 = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
|
||||
$3 = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
|
||||
$4 = torch._ops.aten.select.int($3, 1, 1)
|
||||
$5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''')
|
||||
$0: f32[2, 2] = input('x')
|
||||
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
|
||||
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
|
||||
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
|
||||
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
|
||||
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''')
|
||||
|
||||
def test_optional_tensor_list(self) -> None:
|
||||
def weird(xs):
|
||||
|
|
@ -1585,9 +1585,8 @@ $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)'''
|
|||
torch.ops.my_lib.weird.default([None, x])
|
||||
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.],
|
||||
[1., 1.]]))])''')
|
||||
$0: f32[2, 2] = input('x')
|
||||
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''')
|
||||
|
||||
def test_list_ret(self) -> None:
|
||||
# test all sequence types are permissible returns
|
||||
|
|
@ -1639,9 +1638,9 @@ $1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.],
|
|||
# this test here to make sure we don't regress even further (it
|
||||
# would be bad if calling .detach() once emits 3+ detaches).
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = torch._ops.aten.detach.default($0)
|
||||
$2 = torch._ops.aten.detach.default($1)''')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = torch._ops.aten.detach.default($0)
|
||||
$2: f32[1] = torch._ops.aten.detach.default($1)''')
|
||||
|
||||
def test_storage(self) -> None:
|
||||
# For now, just make sure it doesn't crash. Ideally, we should
|
||||
|
|
@ -1744,14 +1743,14 @@ $2 = torch._ops.aten.detach.default($1)''')
|
|||
# self.assertEqual(escape[0]._version, x._version)
|
||||
|
||||
self.assertExpectedInline('\n'.join(logs), '''\
|
||||
$0 = input('x')
|
||||
$1 = input('x.grad')
|
||||
$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
|
||||
$3 = input('grad_output')
|
||||
$0: f32[1] = input('x')
|
||||
$1: f32[1] = input('x.grad')
|
||||
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
|
||||
$3: f32[1] = input('grad_output')
|
||||
True = torch._ops.aten.is_same_size.default($2, $3)
|
||||
$4 = torch._ops.aten.mul.Tensor($3, 2)
|
||||
$5 = torch._ops.aten.mul.Tensor($4, $0)
|
||||
$6 = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
|
||||
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
|
||||
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''')
|
||||
|
||||
def test_subclass_creation(self):
|
||||
# Make sure these statements runs without error
|
||||
|
|
@ -1934,7 +1933,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''')
|
|||
with LoggingTensorMode():
|
||||
torch.empty([])
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
|
||||
def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
|
||||
x = torch.randn([])
|
||||
|
|
@ -1942,8 +1941,7 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo
|
|||
with capture_logs(is_mode=True) as logs:
|
||||
with LoggingTensorMode():
|
||||
x + y
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$2 = torch._ops.aten.add.Tensor($0, $1)""")
|
||||
self.assertExpectedInline('\n'.join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""")
|
||||
|
||||
def test_nested_push_logging_tensor_mode(self):
|
||||
x = torch.randn([])
|
||||
|
|
@ -1955,10 +1953,10 @@ $2 = torch._ops.aten.add.Tensor($0, $1)""")
|
|||
x + y
|
||||
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3 = torch._ops.aten.add.Tensor($1, $2)
|
||||
$3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
|
||||
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
def test_capture_logs_with_torch_dispatch_mode(self):
|
||||
x = torch.randn([])
|
||||
|
|
@ -1967,8 +1965,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
|||
torch.empty([])
|
||||
x + y
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
x = torch.randn([])
|
||||
y = torch.randn([])
|
||||
|
|
@ -1979,10 +1977,10 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
|||
x + y
|
||||
|
||||
self.assertExpectedInline('\n'.join(logs2), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3 = torch._ops.aten.add.Tensor($1, $2)
|
||||
$3 = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
|
||||
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""")
|
||||
|
||||
self.assertEqual(logs1, logs2)
|
||||
|
||||
|
|
@ -2238,8 +2236,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""")
|
|||
with reenabled:
|
||||
torch.empty([])
|
||||
self.assertExpectedInline('\n'.join(logs), """\
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
|
||||
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""")
|
||||
|
||||
|
||||
def test_error_using_class_method_on_mode(self):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,31 @@
|
|||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
from typing import Iterator, List
|
||||
from typing import Iterator, List, Optional
|
||||
import logging
|
||||
import contextlib
|
||||
import itertools
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils.weak import WeakTensorKeyDictionary
|
||||
import functools
|
||||
from torch._C._profiler import gather_traceback, symbolize_tracebacks
|
||||
|
||||
|
||||
_dtype_abbrs = {
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float64: "f64",
|
||||
torch.float32: "f32",
|
||||
torch.float16: "f16",
|
||||
torch.complex32: "c32",
|
||||
torch.complex64: "c64",
|
||||
torch.complex128: "c128",
|
||||
torch.int8: "i8",
|
||||
torch.int16: "i16",
|
||||
torch.int32: "i32",
|
||||
torch.int64: "i64",
|
||||
torch.bool: "b8",
|
||||
torch.uint8: "u8",
|
||||
}
|
||||
|
||||
# How the chain of calls works for LoggingTensor:
|
||||
# 1. Call torch.sin
|
||||
# 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely
|
||||
|
|
@ -76,53 +95,88 @@ class LoggingTensorReentrant(LoggingTensor):
|
|||
|
||||
# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
|
||||
class LoggingTensorHandler(logging.Handler):
|
||||
log_list: List[str]
|
||||
next_shortid: int
|
||||
|
||||
def __init__(self, log_list: List[str], use_shortid_for_all_tensors: bool) -> None:
|
||||
def __init__(
|
||||
self, log_list: List[str], use_shortid_for_all_tensors: bool,
|
||||
with_type: bool, tracebacks_list: Optional[List]) -> None:
|
||||
logging.Handler.__init__(self)
|
||||
self.log_list = log_list
|
||||
self.next_shortid = 0
|
||||
self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
|
||||
self.tracebacks_list = tracebacks_list
|
||||
self.memo = WeakTensorKeyDictionary()
|
||||
self.next_id = 0
|
||||
self.with_type = with_type
|
||||
|
||||
# WARNING: not deterministic over multiple threads, this matters for
|
||||
# autograd
|
||||
def _shortid(self, o: object) -> int:
|
||||
if not hasattr(o, '_shortid'):
|
||||
o._shortid = self.next_shortid # type: ignore[attr-defined]
|
||||
self.next_shortid += 1
|
||||
return o._shortid # type: ignore[attr-defined]
|
||||
def _shortid(self, t: torch.Tensor) -> int:
|
||||
if t not in self.memo:
|
||||
self.memo[t] = self.next_id
|
||||
self.next_id += 1
|
||||
return self.memo[t]
|
||||
|
||||
def _fmt(self, a: object) -> str:
|
||||
def _fmt(self, a: object, with_type: bool = False) -> str:
|
||||
cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
|
||||
return f'${self._shortid(a)}' if isinstance(a, cond_cls) else repr(a)
|
||||
if isinstance(a, cond_cls):
|
||||
maybe_type = ""
|
||||
if with_type and self.with_type:
|
||||
maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]"
|
||||
x = f"${self._shortid(a)}{maybe_type}"
|
||||
return x
|
||||
else:
|
||||
return repr(a)
|
||||
|
||||
def emit(self, record):
|
||||
fmt_args = ", ".join(itertools.chain(
|
||||
(self._fmt(a) for a in record.args[0]),
|
||||
(f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
|
||||
))
|
||||
fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
|
||||
if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
|
||||
fmt_args = ", ".join(
|
||||
itertools.chain(
|
||||
(str(tree_map(self._fmt, a)) for a in record.args[0]),
|
||||
(f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()),
|
||||
)
|
||||
)
|
||||
fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2])
|
||||
self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
|
||||
if self.tracebacks_list is not None:
|
||||
self.tracebacks_list.append(record.traceback)
|
||||
|
||||
def log_input(name: str, var: object):
|
||||
logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,))
|
||||
logging.getLogger("LoggingTensor").info("input", (name,), {}, var)
|
||||
|
||||
class GatherTraceback(logging.Filter):
    """Logging filter that attaches a captured traceback to every record.

    The filter never rejects a record; it is used purely for its side effect
    of setting ``record.traceback`` to the value returned by
    ``gather_traceback`` so that a handler can collect it afterwards.

    Args:
        python: forwarded to ``gather_traceback`` as ``python=``.
        script: forwarded to ``gather_traceback`` as ``script=``.
        cpp: forwarded to ``gather_traceback`` as ``cpp=``.
    """

    def __init__(self, python=True, script=True, cpp=False):
        # Run the base initializer so the standard logging.Filter attributes
        # (``name`` / ``nlen``) exist on the instance.
        super().__init__()
        self.python = python
        self.script = script
        self.cpp = cpp

    def filter(self, record):
        """Store a freshly gathered traceback on ``record`` and keep the record."""
        record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp)
        # Always True: this filter annotates records, it does not drop them.
        return True
|
||||
|
||||
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
    """Capture records sent to the "LoggingTensor" logger for the ``with`` body.

    Args:
        is_mode: passed to ``LoggingTensorHandler`` as
            ``use_shortid_for_all_tensors``.
        python_tb, script_tb, cpp_tb: when any of these is truthy, a
            ``GatherTraceback`` filter is installed on the logger and the
            gathered tracebacks are collected alongside the log lines.

    Yields:
        ``log_list`` when no traceback collection was requested, otherwise the
        tuple ``(log_list, tracebacks_list)``. NOTE(review): the return
        annotation only describes the first form.

    On exit the handler (and the filter, if one was installed) is removed from
    the logger, and ``tracebacks_list`` is replaced in place by the result of
    ``symbolize_tracebacks``.
    """
    collect_traceback = python_tb or script_tb or cpp_tb
    logger = logging.getLogger("LoggingTensor")
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,
        use_shortid_for_all_tensors=is_mode,
        tracebacks_list=tracebacks_list if collect_traceback else None
    )
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    # Keep a reference to the filter so it can be detached again on exit;
    # otherwise it stays on the shared logger and every later capture keeps
    # gathering tracebacks.
    tb_filter = None
    if collect_traceback:
        tb_filter = GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)
        logger.addFilter(tb_filter)
    try:
        if collect_traceback:
            yield log_list, tracebacks_list
        else:
            yield log_list
    finally:
        # Detach from the logger first so cleanup happens even if
        # symbolization below raises.
        logger.removeHandler(handler)
        if tb_filter is not None:
            logger.removeFilter(tb_filter)
            # Replace the raw captured tracebacks with their symbolized form,
            # mutating the list the caller already holds.
            symbolized_tracebacks = symbolize_tracebacks(tracebacks_list)
            tracebacks_list.clear()
            tracebacks_list.extend(symbolized_tracebacks)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def capture_logs_with_logging_tensor_mode():
|
||||
with LoggingTensorMode(), capture_logs(True) as logs:
|
||||
def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False):
|
||||
with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs:
|
||||
yield logs
|
||||
|
|
|
|||
Loading…
Reference in a new issue