Add basic autograd TORCH_LOGS support (#115438)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115438
Approved by: https://github.com/albanD
soulitzer 2023-12-19 20:51:09 -05:00 committed by PyTorch MergeBot
parent cfbf647adb
commit cfb3cd11c1
7 changed files with 151 additions and 53 deletions
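For context (not part of the diff): a minimal sketch of how the new logging component is enabled. The `autograd` alias can be set through `torch._logging.set_logs` as below, or via the `TORCH_LOGS` environment variable (e.g. `TORCH_LOGS="autograd"`); once the effective level is at or below INFO, the engine logs each autograd node as it executes.

import logging

import torch

# Enable the "autograd" logging component added by this PR.
torch._logging.set_logs(autograd=logging.INFO)

x = torch.rand(10, requires_grad=True)
y = x.mul(2).sum()
# Expect one INFO record per executed node, e.g. SumBackward0, MulBackward0, AccumulateGrad.
y.backward()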


@@ -0,0 +1,33 @@
# Owner(s): ["module: autograd"]

import logging

import torch
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test


class TestAutogradLogging(LoggingTestCase):
    @make_logging_test(autograd=logging.INFO)
    def test_logging(self, records):
        a = torch.rand(10, requires_grad=True)
        b = a.mul(2).div(3).sum()
        c = b.clone()
        torch.autograd.backward((b, c))

        self.assertEqual(len(records), 5)
        expected = [
            "CloneBackward0",
            "SumBackward0",
            "DivBackward0",
            "MulBackward0",
            "AccumulateGrad",
        ]

        for i, record in enumerate(records):
            self.assertIn(expected[i], record.getMessage())


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()


@@ -469,47 +469,49 @@ class TestProfilerTree(TestCase):
[memory]
aten::fill_
<built-in method append of list object at 0xXXXXXXXXXXXX>
<built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
autograd::engine::evaluate_function: PowBackward0
PowBackward0
aten::pow
aten::result_type
aten::to
[memory]
aten::copy_
aten::mul
[memory]
aten::mul
torch/autograd/graph.py(...): _engine_run_backward
logging/__init__.py(...): getEffectiveLevel
<built-in method run_backward of torch._C._EngineBase object at 0xXXXXXXXXXXXX>
autograd::engine::evaluate_function: PowBackward0
PowBackward0
aten::pow
aten::result_type
aten::to
aten::_to_copy
aten::empty_strided
[memory]
aten::copy_
[memory]
aten::copy_
aten::mul
[memory]
aten::mul
aten::to
aten::_to_copy
aten::empty_strided
[memory]
aten::copy_
[memory]
[memory]
[memory]
aten::mul
[memory]
[memory]
aten::mul
[memory]
[memory]
[memory]
[memory]
autograd::engine::evaluate_function: SubBackward0
SubBackward0
aten::neg
[memory]
[memory]
autograd::engine::evaluate_function: AddBackward0
AddBackward0
autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
aten::new_empty_strided
aten::empty_strided
autograd::engine::evaluate_function: SubBackward0
SubBackward0
aten::neg
[memory]
aten::copy_
autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
aten::detach
detach
[memory]
autograd::engine::evaluate_function: AddBackward0
AddBackward0
autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
aten::new_empty_strided
aten::empty_strided
[memory]
aten::copy_
autograd::engine::evaluate_function: torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
aten::detach
detach
[memory]
torch/profiler/profiler.py(...): __exit__
torch/profiler/profiler.py(...): stop


@@ -11880,6 +11880,7 @@ class TestAutogradMultipleDispatch(TestCase):
from autograd.test_complex import TestAutogradComplex # noqa: F401
from autograd.test_functional import TestAutogradFunctional # noqa: F401
from autograd.test_logging import TestAutogradLogging # noqa: F401
# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
instantiate_device_type_tests(


@@ -169,6 +169,7 @@ def set_logs(
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
@@ -240,6 +241,9 @@ def set_logs(
        aot (:class:`Optional[int]`):
            The log level for the AOTAutograd component. Default: ``logging.WARN``

        autograd (:class:`Optional[int]`):
            The log level for autograd. Default: ``logging.WARN``

        inductor (:class:`Optional[int]`):
            The log level for the TorchInductor component. Default: ``logging.WARN``
@@ -392,6 +396,7 @@ def set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,


@@ -6,6 +6,7 @@ DISTRIBUTED = ["torch.distributed", "torch._dynamo.backends.distributed"]
register_log("dynamo", ["torch._dynamo", *DYNAMIC])
register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
register_log("autograd", "torch.autograd")
register_log("inductor", "torch._inductor")
register_log("dynamic", DYNAMIC)
register_log("torch", "torch")


@@ -27,6 +27,7 @@ from .grad_mode import (
    set_multithreading_enabled,
)
from .gradcheck import gradcheck, gradgradcheck
from .graph import _engine_run_backward
from .variable import Variable
@@ -263,7 +264,7 @@ def backward(
    # The reason we repeat the same comment below is that
    # some Python versions print out the first line of a multi-line function
    # calls in the traceback and some print out the last line
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    _engine_run_backward(
        tensors,
        grad_tensors_,
        retain_graph,
@@ -271,7 +272,7 @@ def backward(
        inputs,
        allow_unreachable=True,
        accumulate_grad=True,
    )  # Calls into the C++ engine to run the backward pass
    )
def grad(
@@ -394,7 +395,7 @@ def grad(
    if is_grads_batched:

        def vjp(gO):
            return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            return _engine_run_backward(
                t_outputs,
                gO,
                retain_graph,
@@ -402,13 +403,13 @@ def grad(
                inputs,
                allow_unused,
                accumulate_grad=False,
            )  # Calls into the C++ engine to run the backward pass
            )

        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
            grad_outputs_
        )
    else:
        result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        result = _engine_run_backward(
            t_outputs,
            grad_outputs_,
            retain_graph,
@@ -416,7 +417,7 @@ def grad(
            inputs,
            allow_unused,
            accumulate_grad=False,
        )  # Calls into the C++ engine to run the backward pass
        )

    if materialize_grads:
        if any(
            result[i] is None and not is_tensor_like(inputs[i])


@@ -1,13 +1,19 @@
import abc
import collections
import contextlib
import logging
import weakref
from collections import defaultdict, namedtuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from typing import Any, Callable, Deque, Dict, List, Optional, Sequence, Set, Tuple

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)

__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
@@ -135,6 +141,13 @@ class Node(abc.ABC):
        return NotImplemented


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ("node output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
@@ -153,10 +166,7 @@ def get_gradient_edge(tensor):
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = tensor.grad_fn
    if grad_fn is None:
        # Do an op to force AccumulateGrad lazy creation and get it
        grad_fn = tensor.view_as(tensor).grad_fn.next_functions[0][0]
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr default to 0 which is the right value
    # for the AccumulateGrad node.
@@ -407,14 +417,7 @@ def register_multi_grad_hook(
    nb_calls = None
    buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

    def get_grad_fn(t):
        # or grad accumulator
        if t.requires_grad and t.grad_fn is None:
            return t.expand_as(t).grad_fn.next_functions[0][0]
        else:
            return t.grad_fn

    grad_fns = list(map(get_grad_fn, tensors))
    grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
    len_tensors = len(tensors)

    def get_inner_hook(idx):
@@ -629,3 +632,55 @@ def allow_mutation_on_saved_tensors():
finally:
ctx.clear()
_allow_mutation_on_saved_tensors_enabled = False
def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)
            yield node

    def prehook(grad_output):
        node = torch._C._current_autograd_node()
        log_str = f"Executing: {node} with grad_output: {grad_output}"
        log.info(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks


def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.INFO
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )  # Calls into the C++ engine to run the backward pass
    finally:
        if attach_logging_hooks:
            unregister_hooks()
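For illustration (not part of the diff), a standalone sketch of the `Node.register_prehook` mechanism the new logging hooks are built on: a prehook registered on a backward node fires just before the engine executes that node, which is how the hooks above observe execution order. The variable names are arbitrary.

import torch

x = torch.rand(3, requires_grad=True)
y = x.mul(2).sum()

def prehook(grad_outputs):
    # The currently executing node is available while the prehook runs,
    # exactly as used by the prehook in this PR.
    print("about to execute:", torch._C._current_autograd_node())

handle = y.grad_fn.register_prehook(prehook)  # y.grad_fn is SumBackward0
y.backward()
handle.remove()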