From 3912b722f318192fc9e394ff1fd8cb4e0927c906 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 21 Jun 2023 12:12:52 -0400 Subject: [PATCH] Upgrade LoggingTensor mode and add traceback collection (#103734) Parts borrowed from: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/103734 Approved by: https://github.com/albanD --- test/test_functionalization.py | 4 +- test/test_prims.py | 4 +- test/test_python_dispatch.py | 110 +++++++++++----------- torch/testing/_internal/logging_tensor.py | 108 +++++++++++++++------ 4 files changed, 139 insertions(+), 87 deletions(-) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 1b43eb06cce..9a66b0819ed 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -1312,8 +1312,8 @@ def forward(self, arg0_1): # Make sure that functionalization ran the "+" kernel # with a functional + non-functional tensor, and wrapped the output appropriately. self.assertExpectedInline('\n'.join(logs), """\ -$2 = torch._ops.aten.add.Tensor($0, $1) -$3 = torch._ops.aten.add.Tensor($2, 1)""") +$2: f32[4] = torch._ops.aten.add.Tensor($0, $1) +$3: f32[4] = torch._ops.aten.add.Tensor($2, 1)""") def test_mixed_wrappers_invalid(self): x1_not_functional = torch.ones(4) diff --git a/test/test_prims.py b/test/test_prims.py index ad1e63f58a1..b7a93a1b814 100644 --- a/test/test_prims.py +++ b/test/test_prims.py @@ -1155,8 +1155,8 @@ class TestPrimsBasic(TestCase): log_input("input", r) prims.sin(r) self.assertExpectedInline('\n'.join(logs), """\ -$0 = input('input') -$1 = torch._ops.prims.sin.default($0)""") +$0: f32[2] = input('input') +$1: f32[2] = torch._ops.prims.sin.default($0)""") def test_mul_complex(self): prims.mul(torch.randn(2), 1 + 1j) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 6c76fda3ff9..77aa663bf09 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1485,13 +1485,13 @@ class TestPythonDispatch(TestCase): # TODO: figure out why broken # self.assertEqual(saved_x._version, x._version) self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.aten.mul.Tensor($0, $0) -$2 = input('grad_y') +$0: f32[1] = input('x') +$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0) +$2: f32[1] = input('grad_y') True = torch._ops.aten.is_same_size.default($1, $2) -$3 = torch._ops.aten.mul.Tensor($2, $0) -$4 = torch._ops.aten.mul.Tensor($2, $0) -$5 = torch._ops.aten.add.Tensor($4, $3)''') +$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0) +$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0) +$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)''') def test_out(self) -> None: with capture_logs() as logs: @@ -1505,9 +1505,9 @@ $5 = torch._ops.aten.add.Tensor($4, $3)''') # TODO: arguably this shouldn't pass and we should complain # that out isn't a kwarg self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = input('y') -$2 = torch._ops.aten.abs.out($0, out=$1)''') +$0: f32[1] = input('x') +$1: f32[1] = input('y') +$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)''') def test_kwarg_only(self) -> None: with capture_logs() as logs: @@ -1526,14 +1526,14 @@ $2 = torch._ops.aten.abs.out($0, out=$1)''') # The expectation is that beta/alpha don't show up when they're # defaulted. This is even if the user explicitly specified it. self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = input('y') -$2 = input('z') -$3 = torch._ops.aten.addmv.default($0, $1, $2) -$4 = torch._ops.aten.addmv.default($0, $1, $2) -$5 = torch._ops.aten.addmv.default($0, $1, $2, beta=2) -$6 = torch._ops.aten.addmv.default($0, $1, $2, alpha=2) -$7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''') +$0: f32[1] = input('x') +$1: f32[1, 1] = input('y') +$2: f32[1] = input('z') +$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2) +$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2) +$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2) +$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2) +$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''') def test_kwarg_only_and_positional_default(self) -> None: with capture_logs() as logs: @@ -1547,11 +1547,11 @@ $7 = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)''') # What we are testing here is that we omit arg2 # if it is defaulted, even if a kwarg is set self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.aten._foobar.default($0) -$2 = torch._ops.aten._foobar.default($0, False) -$3 = torch._ops.aten._foobar.default($0, arg3=False) -$4 = torch._ops.aten._foobar.default($0, False, arg3=False)''') +$0: f32[1] = input('x') +$1: f32[1] = torch._ops.aten._foobar.default($0) +$2: f32[1] = torch._ops.aten._foobar.default($0, False) +$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False) +$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)''') def test_produce_real_type(self) -> None: with capture_logs() as logs: @@ -1564,12 +1564,12 @@ $4 = torch._ops.aten._foobar.default($0, False, arg3=False)''') # triggerable using tensor subclasses (need to use a mode) self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.aten._to_copy.default($0, dtype=torch.float64) -$2 = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64) -$3 = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807) -$4 = torch._ops.aten.select.int($3, 1, 1) -$5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') +$0: f32[2, 2] = input('x') +$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64) +$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64) +$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807) +$4: f32[2] = torch._ops.aten.select.int($3, 1, 1) +$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') def test_optional_tensor_list(self) -> None: def weird(xs): @@ -1585,9 +1585,8 @@ $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''' torch.ops.my_lib.weird.default([None, x]) self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], - [1., 1.]]))])''') +$0: f32[2, 2] = input('x') +$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])''') def test_list_ret(self) -> None: # test all sequence types are permissible returns @@ -1639,9 +1638,9 @@ $1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], # this test here to make sure we don't regress even further (it # would be bad if calling .detach() once emits 3+ detaches). self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = torch._ops.aten.detach.default($0) -$2 = torch._ops.aten.detach.default($1)''') +$0: f32[1] = input('x') +$1: f32[1] = torch._ops.aten.detach.default($0) +$2: f32[1] = torch._ops.aten.detach.default($1)''') def test_storage(self) -> None: # For now, just make sure it doesn't crash. Ideally, we should @@ -1744,14 +1743,14 @@ $2 = torch._ops.aten.detach.default($1)''') # self.assertEqual(escape[0]._version, x._version) self.assertExpectedInline('\n'.join(logs), '''\ -$0 = input('x') -$1 = input('x.grad') -$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2) -$3 = input('grad_output') +$0: f32[1] = input('x') +$1: f32[1] = input('x.grad') +$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2) +$3: f32[1] = input('grad_output') True = torch._ops.aten.is_same_size.default($2, $3) -$4 = torch._ops.aten.mul.Tensor($3, 2) -$5 = torch._ops.aten.mul.Tensor($4, $0) -$6 = torch._ops.aten.add_.Tensor($1, $5)''') +$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2) +$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0) +$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)''') def test_subclass_creation(self): # Make sure these statements runs without error @@ -1934,7 +1933,7 @@ $6 = torch._ops.aten.add_.Tensor($1, $5)''') with LoggingTensorMode(): torch.empty([]) self.assertExpectedInline('\n'.join(logs), """\ -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") def test_torch_dispatch_mode_unrelated_tensors(self) -> None: x = torch.randn([]) @@ -1942,8 +1941,7 @@ $0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memo with capture_logs(is_mode=True) as logs: with LoggingTensorMode(): x + y - self.assertExpectedInline('\n'.join(logs), """\ -$2 = torch._ops.aten.add.Tensor($0, $1)""") + self.assertExpectedInline('\n'.join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""") def test_nested_push_logging_tensor_mode(self): x = torch.randn([]) @@ -1955,10 +1953,10 @@ $2 = torch._ops.aten.add.Tensor($0, $1)""") x + y self.assertExpectedInline('\n'.join(logs), """\ -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$3 = torch._ops.aten.add.Tensor($1, $2) -$3 = torch._ops.aten.add.Tensor($1, $2)""") +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$3: f32[] = torch._ops.aten.add.Tensor($1, $2) +$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""") def test_capture_logs_with_torch_dispatch_mode(self): x = torch.randn([]) @@ -1967,8 +1965,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") torch.empty([]) x + y self.assertExpectedInline('\n'.join(logs), """\ -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$3 = torch._ops.aten.add.Tensor($1, $2)""") +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""") x = torch.randn([]) y = torch.randn([]) @@ -1979,10 +1977,10 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") x + y self.assertExpectedInline('\n'.join(logs2), """\ -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$3 = torch._ops.aten.add.Tensor($1, $2) -$3 = torch._ops.aten.add.Tensor($1, $2)""") +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$3: f32[] = torch._ops.aten.add.Tensor($1, $2) +$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""") self.assertEqual(logs1, logs2) @@ -2238,8 +2236,8 @@ $3 = torch._ops.aten.add.Tensor($1, $2)""") with reenabled: torch.empty([]) self.assertExpectedInline('\n'.join(logs), """\ -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) -$0 = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) +$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""") def test_error_using_class_method_on_mode(self): diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py index 649c411fc94..4c53517ae49 100644 --- a/torch/testing/_internal/logging_tensor.py +++ b/torch/testing/_internal/logging_tensor.py @@ -1,12 +1,31 @@ import torch from torch.utils._pytree import tree_map -from typing import Iterator, List +from typing import Iterator, List, Optional import logging import contextlib import itertools from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.weak import WeakTensorKeyDictionary +import functools +from torch._C._profiler import gather_traceback, symbolize_tracebacks +_dtype_abbrs = { + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", +} + # How the chain of calls works for LoggingTensor: # 1. Call torch.sin # 2. Attempt __torch_function__. In LoggingTensor torch function is disabled so we bypass it entirely @@ -76,53 +95,88 @@ class LoggingTensorReentrant(LoggingTensor): # https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list class LoggingTensorHandler(logging.Handler): - log_list: List[str] - next_shortid: int - - def __init__(self, log_list: List[str], use_shortid_for_all_tensors: bool) -> None: + def __init__( + self, log_list: List[str], use_shortid_for_all_tensors: bool, + with_type: bool, tracebacks_list: Optional[List]) -> None: logging.Handler.__init__(self) self.log_list = log_list - self.next_shortid = 0 self.use_shortid_for_all_tensors = use_shortid_for_all_tensors + self.tracebacks_list = tracebacks_list + self.memo = WeakTensorKeyDictionary() + self.next_id = 0 + self.with_type = with_type - # WARNING: not deterministic over multiple threads, this matters for - # autograd - def _shortid(self, o: object) -> int: - if not hasattr(o, '_shortid'): - o._shortid = self.next_shortid # type: ignore[attr-defined] - self.next_shortid += 1 - return o._shortid # type: ignore[attr-defined] + def _shortid(self, t: torch.Tensor) -> int: + if t not in self.memo: + self.memo[t] = self.next_id + self.next_id += 1 + return self.memo[t] - def _fmt(self, a: object) -> str: + def _fmt(self, a: object, with_type: bool = False) -> str: cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor - return f'${self._shortid(a)}' if isinstance(a, cond_cls) else repr(a) + if isinstance(a, cond_cls): + maybe_type = "" + if with_type and self.with_type: + maybe_type = f": {_dtype_abbrs[a.dtype]}[{', '.join(map(str, a.shape))}]" + x = f"${self._shortid(a)}{maybe_type}" + return x + else: + return repr(a) def emit(self, record): - fmt_args = ", ".join(itertools.chain( - (self._fmt(a) for a in record.args[0]), - (f"{k}={self._fmt(v)}" for k, v in record.args[1].items()) - )) - fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \ - if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2]) + fmt_args = ", ".join( + itertools.chain( + (str(tree_map(self._fmt, a)) for a in record.args[0]), + (f"{k}={str(tree_map(self._fmt, v))}" for k, v in record.args[1].items()), + ) + ) + fmt_rets = tree_map(functools.partial(self._fmt, with_type=True), record.args[2]) self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})') + if self.tracebacks_list is not None: + self.tracebacks_list.append(record.traceback) def log_input(name: str, var: object): - logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,)) + logging.getLogger("LoggingTensor").info("input", (name,), {}, var) + +class GatherTraceback(logging.Filter): + def __init__(self, python=True, script=True, cpp=False): + self.python = python + self.script = script + self.cpp = cpp + + def filter(self, record): + record.traceback = gather_traceback(python=self.python, script=self.script, cpp=self.cpp) + return True @contextlib.contextmanager -def capture_logs(is_mode=False) -> Iterator[List[str]]: +def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]: + collect_traceback = python_tb or script_tb or cpp_tb logger = logging.getLogger("LoggingTensor") log_list: List[str] = [] - handler = LoggingTensorHandler(log_list, use_shortid_for_all_tensors=is_mode) + tracebacks_list: List[str] = [] + handler = LoggingTensorHandler( + log_list, + with_type=True, + use_shortid_for_all_tensors=is_mode, + tracebacks_list=tracebacks_list if collect_traceback else None + ) logger.addHandler(handler) logger.setLevel(logging.INFO) logger.propagate = False + if collect_traceback: + logger.addFilter(GatherTraceback(python=python_tb, script=script_tb, cpp=cpp_tb)) try: - yield log_list + if collect_traceback: + yield log_list, tracebacks_list + else: + yield log_list finally: + symbolized_tracebacks = symbolize_tracebacks(tracebacks_list) + tracebacks_list.clear() + tracebacks_list.extend(symbolized_tracebacks) logger.removeHandler(handler) @contextlib.contextmanager -def capture_logs_with_logging_tensor_mode(): - with LoggingTensorMode(), capture_logs(True) as logs: +def capture_logs_with_logging_tensor_mode(python_tb=False, script_tb=False, cpp_tb=False): + with LoggingTensorMode(), capture_logs(True, python_tb, script_tb, cpp_tb) as logs: yield logs