diff --git a/mypy-strict.ini b/mypy-strict.ini index 460599699c4..81c66d5239e 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -40,6 +40,7 @@ files = .github, benchmarks/instruction_counts, tools, + torch/profiler/_memory_profiler.py, torch/utils/_pytree.py, torch/utils/benchmark/utils/common.py, torch/utils/benchmark/utils/timer.py, diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py new file mode 100644 index 00000000000..c725f8bec51 --- /dev/null +++ b/test/profiler/test_memory_profiler.py @@ -0,0 +1,224 @@ +# Owner(s): ["oncall: profiler"] +import functools +from typing import Iterator, Optional + +import torch +from torch._C._profiler import _EventType +from torch.profiler import _memory_profiler, _utils +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +profile = functools.partial( + torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True +) + + +class ScaleLayer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.scale + + +@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.") +class TestIdentifyGradients(TestCase): + def gradient_detected( + self, + prof: torch.profiler.profile, + ctx: _EventType, + grad_tensor: torch.Tensor, + parameter: Optional[torch.Tensor] = None, + ) -> None: + + # This is not an exhaustive check, but for the purpose of unit testing + # it is sufficient. + def key_matches_tensor(key, tensor) -> bool: + # Vacuous case. + if tensor is None: + return True + + if key is None: + return False + + return tensor.storage().data_ptr() == key.storage.ptr + + tree = prof.profiler.kineto_results.experimental_event_tree() + for node in _utils.traverse_dfs(tree): + for p_key, p_grad_key in _memory_profiler.extract_gradients(node): + if node.tag == ctx and key_matches_tensor(p_grad_key, grad_tensor): + if parameter is None: + return True # Don't need to check parameter; we're done. + + elif p_key is not None: + # For a complex workflow a gradient could correspond to + # different parameters at different points in a trace. + # However this will not happen in the relatively simple + # cases tested here, so if `extract_gradients` identifies + # the parameter corresponding to a particular gradient it + # must be the one we expect. + self.assertTrue(key_matches_tensor(p_key, parameter)) + return True + + return False + + def assertGradientDetected(self, name: str, *args, **kwargs) -> None: + self.assertTrue( + self.gradient_detected(*args, **kwargs), + f"Failed to identify gradient `{name}` from profile.", + ) + + def assertOnlyGradients( + self, prof: torch.profiler.profile, tensors: Iterator[torch.Tensor] + ) -> None: + allowed_set = {t.storage().data_ptr() for t in tensors} + + tree = prof.profiler.kineto_results.experimental_event_tree() + for node in _utils.traverse_dfs(tree): + for _, p_grad_key in _memory_profiler.extract_gradients(node): + self.assertTrue( + p_grad_key.storage.ptr in allowed_set, + f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}", + ) + + def test_extract_gradients_low_level(self) -> None: + x = torch.ones((1,)) + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + + def check(cold_start: bool): + self.assertEqual(w0.grad is None, cold_start) + self.assertEqual(w1.grad is None, cold_start) + with profile() as prof: + z = x.expand(4) * w0 + (z * w1).sum().backward() + + # Gradient detection through op inspection does not provide a + # reference to the parameter corresponding to the gradient. + self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_module(self) -> None: + model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()) + named_parameters = {name: p for name, p in model.named_parameters()} + self.assertEqual(len(named_parameters), 3) + + def assert_only_gradients(prof: torch.profiler.profile): + gradients = tuple(i.grad for i in named_parameters.values()) + self.assertFalse(any(i is None for i in gradients)) + self.assertOnlyGradients(prof, gradients) + + def check(cold_start: bool): + x = torch.ones((2, 2)) + with profile() as prof: + model(x).sum().backward() + + for name, p in named_parameters.items(): + # The first time we run a module none of the `.grad` fields + # have been initialized. This is fine; in that case we can + # detect everything we need in the profiled section. + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, p.grad, p), + cold_start, + name, + ) + + # Op based detection should still identify the gradients. + self.assertGradientDetected(name, prof, _EventType.TorchOp, p.grad) + assert_only_gradients(prof) + + # We can detect gradients even when `.backward()` is not called. + with profile() as prof: + model(torch.ones((2, 2))) + + for name, p in named_parameters.items(): + self.assertGradientDetected(name, prof, _EventType.PyCall, p.grad, p) + self.assertFalse( + self.gradient_detected(prof, _EventType.TorchOp, p.grad), name + ) + assert_only_gradients(prof) + + check(cold_start=True) + check(cold_start=False) + + def _test_extract_gradients_from_optimizer(self, set_to_none: bool) -> None: + + x = torch.ones((1,)) + w0 = torch.ones((1,), requires_grad=True) + w1 = torch.ones((1,), requires_grad=True) + optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9) + + def check(cold_start: bool): + self.assertEqual(w0.grad is None, cold_start) + self.assertEqual(w1.grad is None, cold_start) + with profile() as prof: + optimizer.zero_grad(set_to_none=set_to_none) + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + + # Optimizer instrumentation runs late in the step, so we can detect + # gradients for both cold and warm start. + self.assertGradientDetected("w0", prof, _EventType.PyCall, w0.grad, w0) + self.assertGradientDetected("w1", prof, _EventType.PyCall, w1.grad, w1) + + self.assertGradientDetected("w0", prof, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + with profile() as prof: + for _ in range(2): + optimizer.zero_grad(set_to_none=set_to_none) + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + + # Inspected state is cached, so if we replace gradients (as is the + # case for `set_to_none=True`) our python instrumentation will not + # see them. + # TODO(robieta): Should `.step()` be excluded from caching? + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, w0.grad, w0), + set_to_none, + ) + + self.assertNotEqual( + self.gradient_detected(prof, _EventType.PyCall, w1.grad, w1), + set_to_none, + ) + + if set_to_none: + with self.assertRaisesRegex(AssertionError, "Tensor wrongly marked"): + self.assertOnlyGradients(prof, (w0.grad, w1.grad)) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_optimizer(self) -> None: + self._test_extract_gradients_from_optimizer(set_to_none=False) + + def test_extract_gradients_from_optimizer_set_to_none(self) -> None: + self._test_extract_gradients_from_optimizer(set_to_none=True) + + def test_extract_gradients_from_module_and_optimizer(self) -> None: + # Module and optimizer are thoroughly tested individually and should be + # additive. Thus we can manage with a lightweight check that they don't + # interact adversely. + model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + with profile() as prof: + model(torch.ones((2, 2))).sum().backward() + optimizer.step() + + self.assertGradientDetected( + "weight", prof, _EventType.PyCall, model[0].weight.grad, model[0].weight + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index 2843090d61f..da0f191e26b 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -3,6 +3,8 @@ from typing import List, Optional, Tuple, Union from torch._C import device, dtype, layout +from typing_extensions import Literal + # defined in torch/csrc/profiler/python/init.cpp class RecordScope(Enum): @@ -38,11 +40,12 @@ class ProfilerActivity(Enum): CUDA = ... class _EventType(Enum): - Allocation = ... + TorchOp = ... Backend = ... + Allocation = ... + OutOfMemory = ... PyCall = ... PyCCall = ... - TorchOp = ... Kineto = ... class _ExperimentalConfig: @@ -71,6 +74,8 @@ class _ProfilerEvent: start_tid: int start_time_ns: int children: List[_ProfilerEvent] + + # TODO(robieta): remove in favor of `self.typed` extra_fields: Union[ _ExtraFields_TorchOp, _ExtraFields_Backend, @@ -81,6 +86,18 @@ class _ProfilerEvent: _ExtraFields_Kineto, ] + @property + def typed( + self, + ) -> Union[ + Tuple[Literal[_EventType.TorchOp], _ExtraFields_TorchOp], + Tuple[Literal[_EventType.Backend], _ExtraFields_Backend], + Tuple[Literal[_EventType.Allocation], _ExtraFields_Allocation], + Tuple[Literal[_EventType.OutOfMemory], _ExtraFields_OutOfMemory], + Tuple[Literal[_EventType.PyCall], _ExtraFields_PyCall], + Tuple[Literal[_EventType.PyCCall], _ExtraFields_PyCCall], + Tuple[Literal[_EventType.Kineto], _ExtraFields_Kineto], + ]: ... @property def name(self) -> str: ... @property @@ -101,6 +118,8 @@ class _TensorMetadata: storage_data_ptr: Optional[int] id: Optional[int] + @property + def allocation_id(self) -> Optional[int]: ... @property def layout(self) -> layout: ... @property @@ -129,11 +148,12 @@ class _ExtraFields_Backend: ... class _ExtraFields_Allocation: ptr: int id: Optional[int] - allocation_id: Optional[int] alloc_size: int total_allocated: int total_reserved: int + @property + def allocation_id(self) -> Optional[int]: ... @property def device(self) -> device: ... @@ -147,22 +167,47 @@ class _PyFrameState: def file_name(self) -> str: ... class _NNModuleInfo: - @property - def params(self) -> List[Tuple[str, int]]: ... @property def self_ptr(self) -> int: ... @property def cls_ptr(self) -> int: ... @property def cls_name(self) -> str: ... + @property + def parameters( + self, + ) -> List[Tuple[str, _TensorMetadata, Optional[_TensorMetadata]]]: ... + +class _OptimizerInfo: + @property + def parameters( + self, + ) -> List[ + Tuple[ + # Parameter + _TensorMetadata, + # + # Gradient (if present during optimizer.step()) + Optional[_TensorMetadata], + # + # Optimizer state for Parameter as (name, tensor) pairs + List[Tuple[str, _TensorMetadata]], + ] + ]: ... class _ExtraFields_PyCCall: - callsite: _PyFrameState - caller: _PyFrameState - module: Optional[_NNModuleInfo] + @property + def caller(self) -> _PyFrameState: ... class _ExtraFields_PyCall: - caller: _PyFrameState + @property + def callsite(self) -> _PyFrameState: ... + @property + def caller(self) -> _PyFrameState: ... + @property + def module(self) -> Optional[_NNModuleInfo]: ... + @property + def optimizer(self) -> Optional[_OptimizerInfo]: ... class _ExtraFields_Kineto: ... diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 7084a1b598d..2a5839fc6a2 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -251,6 +251,13 @@ void initPythonBindings(PyObject* module) { .def_property_readonly("name", &Result::name) .def_property_readonly("tag", &Result::tag) .def_readonly("extra_fields", &Result::extra_fields_) + .def_property_readonly( + "typed", + [](const Result& r) { + return py::make_tuple( + r.tag(), + py::cast(r.extra_fields_, py::return_value_policy::reference)); + }) .def_property_readonly( "id", [](const Result& r) { diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py new file mode 100644 index 00000000000..cab77193148 --- /dev/null +++ b/torch/profiler/_memory_profiler.py @@ -0,0 +1,114 @@ +import dataclasses +from typing import Any, Iterator, Optional, Tuple + +import torch +from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope + + +@dataclasses.dataclass +class _Storage: + """Bundle storage pointer and id. + + All profiling logic should use `allocation_id`, however it is useful to + print storage pointers for debugging and unit tests sometimes look up + values using the storage data pointer of a live Tensor.""" + + ptr: int + allocation_id: int + + def __repr__(self) -> str: + return f"{hex(self.ptr):>18} ({self.allocation_id})" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, _Storage) and self.allocation_id == other.allocation_id + + def __hash__(self) -> int: + return hash(self.allocation_id) + + +@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True) +class TensorKey: + """Hashable identifier for a storage which has been asigned an ID. + + A detailed description of Tensor IDs and why they are needed is given in + `torch/csrc/profiler/collection.h` when `TensorID` is declared. To + summarize, multiple Storage buffers can map to the same logical Tensor. + This dataclass is used to refer to a concrete in-memory StorageImpl of + a Tensor. + """ + + id: int + storage: _Storage + device: torch.device + + def __repr__(self) -> str: + return f"id={self.id}: {repr(self.storage):<24} ({self.device})" + + @staticmethod + def _make( + tensor_id: Optional[int], + storage_ptr: Optional[int], + allocation_id: Optional[int], + device: torch.device, + ) -> Optional["TensorKey"]: + if ( + tensor_id is not None + and storage_ptr is not None + and allocation_id is not None + ): + return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device) + return None + + @classmethod + def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]: + if t is not None: + return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device) + return None + + +def extract_gradients( + node: _ProfilerEvent, +) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]: + children = node.children + + # AccumulateGrad is used in the Autograd engine to handle gradient updates. + # There are two possible cases: + # 1) This is a newly created gradient Tensor. In that case there is nothing + # to accumulate, so autograd simply detaches the Tensor. + # + # 2) There is a preexisting gradient Tensor and we need to add the newly + # computed update. This is done with an in-place add (aten::add_) op. + # (The underscore suffix denotes "in-place".) + if ( + node.typed[0] == _EventType.TorchOp + and node.typed[1].scope == RecordScope.BACKWARD_FUNCTION + # TODO(robieta): Move away from load bearing names + and node.name == "torch::autograd::AccumulateGrad" + and children + and children[0].typed[0] == _EventType.TorchOp + and children[0].name in ("aten::detach", "aten::add_") + and children[0].typed[1].inputs + and isinstance(children[0].typed[1].inputs[0], _TensorMetadata) + ): + key = TensorKey.from_tensor(children[0].typed[1].inputs[0]) + if key: + yield None, key + + # We directly instrument `torch.nn.Module` and `torch.optim.Optimizer` + # NOTE: The values captured by the python tracer are cached; they can be + # used to build up labels but do not imply that a Tensor was live at + # a particular time. + elif node.typed[0] == _EventType.PyCall: + typed_fields = node.typed[1] + assert typed_fields.module is None or typed_fields.optimizer is None + if typed_fields.module is not None: + for _, p, p_grad in typed_fields.module.parameters: + p_grad_key = TensorKey.from_tensor(p_grad) + if p_grad_key is not None: + yield TensorKey.from_tensor(p), p_grad_key + + if typed_fields.optimizer is not None: + for p, p_grad, _ in typed_fields.optimizer.parameters: + p_grad_key = TensorKey.from_tensor(p_grad) + if p_grad_key is not None: + yield TensorKey.from_tensor(p), p_grad_key