[Profiler] Memory profiler part 13: Add sizes to timeline. (#89356)

If we see an allocation the size is unambiguous. Otherwise we have to use sizes and strides to bound the underlying storage.

Differential Revision: [D40868660](https://our.internmc.facebook.com/intern/diff/D40868660/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89356
Approved by: https://github.com/chaekit
This commit is contained in:
Taylor Robie 2022-12-01 08:35:20 -08:00 committed by PyTorch MergeBot
parent 6727e537a7
commit 63e57280fc
2 changed files with 174 additions and 170 deletions

View file

@ -861,6 +861,9 @@ class TestMemoryProfilerE2E(TestCase):
assert_category(p, _memory_profiler.Category.PARAMETER)
assert_category(p.grad, _memory_profiler.Category.GRADIENT)
# Rely on internal asserts
_ = memory_profile.timeline
def _run_and_format_categories(self, fn, indent=12):
"""Generate summary of assigned categories for expecttest."""
@ -1417,30 +1420,30 @@ class TestMemoryProfilerE2E(TestCase):
def test_memory_timeline(self) -> None:
model = torch.nn.Sequential(
torch.nn.Linear(2, 4, bias=True),
torch.nn.Linear(64, 512, bias=True),
torch.nn.ReLU(),
torch.nn.Linear(4, 4, bias=False),
torch.nn.Linear(512, 512, bias=False),
torch.nn.Softmax(dim=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
with profile() as prof:
x = torch.ones((2, 2))
targets = torch.ones((2, 4))
x = torch.ones((1024, 64))
targets = torch.ones((1024, 512))
y = model(x)
loss = torch.sum((y - targets) ** 2).mean()
loss = torch.nn.functional.mse_loss(y, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()
memory_profile = prof._memory_profile()
timeline = memory_profile.timeline
times = tuple(t for t, _, _ in timeline)
times = tuple(t for t, _, _, _ in timeline)
self.assertTrue(all(t1 >= t0 for t0, t1 in zip(times, times[1:])), times)
self.assertTrue(
all(
(t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0)
for t, action, _ in timeline
for t, action, _, _ in timeline
)
)
@ -1455,174 +1458,101 @@ class TestMemoryProfilerE2E(TestCase):
return f"{category_name(category)} -> {category_name(new_category)}"
return category_name(category)
def format_size(size: int):
if size < 1024:
return f"{size / 1024:3.1f} kB"
return f"{size // 1024} kB"
# We generate sequential IDs for Tensors; however platforms vary
# slightly in the exact computation executed. If this results in
# tensor creation the IDs will be shifted and the unit test will fail.
# (Even though the behavior we're testing is unchanged.) To correct for
# this we assign sequential numbers to the tensors which are actually
# tested, effectively suppressing the extraneous implementation details.
id_map = {}
def id_for_testing(key):
return id_map.setdefault(key.storage.allocation_id, len(id_map))
lines = [
f"{action.name.lower():<25} {format_action(action, key, version):<25} "
f"{key.storage.allocation_id:>2} v{version}"
for _, action, (key, version) in prof._memory_profile().timeline
f"{id_for_testing(key):>3}(v{version}) {format_size(size):>15}"
for _, action, (key, version), size in prof._memory_profile().timeline
# We generally don't care about tiny allocations during memory
# profiling and they add a lot of noise to the unit test.
if size >= 256
]
self.assertExpectedInline(
textwrap.indent("\n".join(lines), " " * 12),
"""\
preexisting PARAMETER 3 v0
preexisting PARAMETER 4 v0
preexisting PARAMETER 7 v0
create INPUT 1 v0
create INPUT 2 v0
create ACTIVATION 5 v0
create ACTIVATION 6 v0
destroy ACTIVATION 5 v0
create ACTIVATION 8 v0
create ACTIVATION 9 v0
destroy ACTIVATION 8 v0
create ACTIVATION 10 v0
create ACTIVATION 11 v0
create ACTIVATION 12 v0
destroy ACTIVATION 11 v0
create ACTIVATION 13 v0
create TEMPORARY 14 v0
create TEMPORARY 15 v0
destroy TEMPORARY 15 v0
destroy TEMPORARY 14 v0
create ACTIVATION 16 v0
create TEMPORARY 17 v0
create TEMPORARY 18 v0
create AUTOGRAD_DETAIL 19 v0
destroy TEMPORARY 18 v0
destroy TEMPORARY 17 v0
destroy ACTIVATION 12 v0
create TEMPORARY 20 v0
create TEMPORARY 21 v0
create TEMPORARY 22 v0
create TEMPORARY 23 v0
destroy TEMPORARY 22 v0
destroy TEMPORARY 21 v0
create AUTOGRAD_DETAIL 24 v0
destroy TEMPORARY 23 v0
destroy TEMPORARY 20 v0
destroy AUTOGRAD_DETAIL 19 v0
destroy ACTIVATION 10 v0
increment_version AUTOGRAD_DETAIL 24 v0
create AUTOGRAD_DETAIL 25 v0
destroy AUTOGRAD_DETAIL 24 v1
create GRADIENT 26 v0
create AUTOGRAD_DETAIL 27 v0
destroy AUTOGRAD_DETAIL 25 v0
create AUTOGRAD_DETAIL 28 v0
destroy AUTOGRAD_DETAIL 27 v0
destroy ACTIVATION 6 v0
create GRADIENT 29 v0
create GRADIENT 30 v0
destroy AUTOGRAD_DETAIL 28 v0
destroy ACTIVATION 16 v0
create OPTIMIZER_STATE 31 v0
increment_version OPTIMIZER_STATE 31 v0
create OPTIMIZER_STATE 32 v0
create OPTIMIZER_STATE 33 v0
create OPTIMIZER_STATE 34 v0
increment_version OPTIMIZER_STATE 34 v0
create OPTIMIZER_STATE 35 v0
create OPTIMIZER_STATE 36 v0
create OPTIMIZER_STATE 37 v0
increment_version OPTIMIZER_STATE 37 v0
create OPTIMIZER_STATE 38 v0
create OPTIMIZER_STATE 39 v0
create ??? 40 v0
increment_version OPTIMIZER_STATE 31 v1
create TEMPORARY 41 v0
destroy TEMPORARY 41 v0
destroy ??? 40 v0
create INPUT 42 v0
increment_version OPTIMIZER_STATE 32 v0
create TEMPORARY 43 v0
destroy TEMPORARY 43 v0
destroy INPUT 42 v0
increment_version OPTIMIZER_STATE 32 v1
create INPUT 44 v0
increment_version OPTIMIZER_STATE 33 v0
create TEMPORARY 45 v0
destroy TEMPORARY 45 v0
destroy INPUT 44 v0
increment_version OPTIMIZER_STATE 33 v1
create ??? 46 v0
create INPUT 47 v0
create TEMPORARY 48 v0
create ??? 49 v0
destroy TEMPORARY 48 v0
destroy INPUT 47 v0
destroy ??? 46 v0
create INPUT 50 v0
increment_version ??? 49 v0
create TEMPORARY 51 v0
destroy TEMPORARY 51 v0
destroy INPUT 50 v0
increment_version PARAMETER 3 v0
create ??? 52 v0
increment_version OPTIMIZER_STATE 34 v1
create TEMPORARY 53 v0
destroy TEMPORARY 53 v0
destroy ??? 52 v0
create INPUT 54 v0
increment_version OPTIMIZER_STATE 35 v0
create TEMPORARY 55 v0
destroy TEMPORARY 55 v0
destroy INPUT 54 v0
increment_version OPTIMIZER_STATE 35 v1
create INPUT 56 v0
increment_version OPTIMIZER_STATE 36 v0
create TEMPORARY 57 v0
destroy TEMPORARY 57 v0
destroy INPUT 56 v0
increment_version OPTIMIZER_STATE 36 v1
create ??? 58 v0
create INPUT 59 v0
create TEMPORARY 60 v0
create ??? 61 v0
destroy TEMPORARY 60 v0
destroy INPUT 59 v0
destroy ??? 58 v0
create INPUT 62 v0
increment_version ??? 61 v0
create TEMPORARY 63 v0
destroy TEMPORARY 63 v0
destroy INPUT 62 v0
destroy ??? 49 v1
increment_version PARAMETER 4 v0
create ??? 64 v0
increment_version OPTIMIZER_STATE 37 v1
create TEMPORARY 65 v0
destroy TEMPORARY 65 v0
destroy ??? 64 v0
create INPUT 66 v0
increment_version OPTIMIZER_STATE 38 v0
create TEMPORARY 67 v0
destroy TEMPORARY 67 v0
destroy INPUT 66 v0
increment_version OPTIMIZER_STATE 38 v1
create INPUT 68 v0
increment_version OPTIMIZER_STATE 39 v0
create TEMPORARY 69 v0
destroy TEMPORARY 69 v0
destroy INPUT 68 v0
increment_version OPTIMIZER_STATE 39 v1
create ??? 70 v0
create INPUT 71 v0
create TEMPORARY 72 v0
create ??? 73 v0
destroy TEMPORARY 72 v0
destroy INPUT 71 v0
destroy ??? 70 v0
create INPUT 74 v0
increment_version ??? 73 v0
create TEMPORARY 75 v0
destroy TEMPORARY 75 v0
destroy INPUT 74 v0
destroy ??? 61 v1
increment_version PARAMETER 7 v0
destroy ??? 73 v1
increment_version GRADIENT 29 v0
increment_version GRADIENT 30 v0
increment_version GRADIENT 26 v0""")
preexisting PARAMETER 0(v0) 128 kB
preexisting PARAMETER 1(v0) 2 kB
preexisting PARAMETER 2(v0) 1024 kB
create INPUT 3(v0) 256 kB
create INPUT 4(v0) 2048 kB
create ACTIVATION 5(v0) 2048 kB
create ACTIVATION 6(v0) 2048 kB
destroy ACTIVATION 5(v0) 2048 kB
create ACTIVATION 7(v0) 2048 kB
create ACTIVATION 8(v0) 2048 kB
destroy ACTIVATION 7(v0) 2048 kB
create ACTIVATION 9(v0) 2048 kB
create TEMPORARY 10(v0) 2048 kB
destroy TEMPORARY 10(v0) 2048 kB
create AUTOGRAD_DETAIL 11(v0) 2048 kB
create AUTOGRAD_DETAIL 12(v0) 2048 kB
destroy AUTOGRAD_DETAIL 11(v0) 2048 kB
create GRADIENT 13(v0) 1024 kB
create AUTOGRAD_DETAIL 14(v0) 2048 kB
destroy AUTOGRAD_DETAIL 12(v0) 2048 kB
create AUTOGRAD_DETAIL 15(v0) 2048 kB
destroy AUTOGRAD_DETAIL 14(v0) 2048 kB
destroy ACTIVATION 6(v0) 2048 kB
create GRADIENT 16(v0) 128 kB
create GRADIENT 17(v0) 2 kB
destroy AUTOGRAD_DETAIL 15(v0) 2048 kB
create OPTIMIZER_STATE 18(v0) 128 kB
create OPTIMIZER_STATE 19(v0) 128 kB
create OPTIMIZER_STATE 20(v0) 2 kB
create OPTIMIZER_STATE 21(v0) 2 kB
create OPTIMIZER_STATE 22(v0) 1024 kB
create OPTIMIZER_STATE 23(v0) 1024 kB
increment_version OPTIMIZER_STATE 18(v0) 128 kB
increment_version OPTIMIZER_STATE 18(v1) 128 kB
increment_version OPTIMIZER_STATE 19(v0) 128 kB
increment_version OPTIMIZER_STATE 19(v1) 128 kB
create ??? 24(v0) 128 kB
create ??? 25(v0) 128 kB
destroy ??? 24(v0) 128 kB
increment_version ??? 25(v0) 128 kB
increment_version PARAMETER 0(v0) 128 kB
increment_version OPTIMIZER_STATE 20(v0) 2 kB
increment_version OPTIMIZER_STATE 20(v1) 2 kB
increment_version OPTIMIZER_STATE 21(v0) 2 kB
increment_version OPTIMIZER_STATE 21(v1) 2 kB
create ??? 26(v0) 2 kB
create ??? 27(v0) 2 kB
destroy ??? 26(v0) 2 kB
increment_version ??? 27(v0) 2 kB
destroy ??? 25(v1) 128 kB
increment_version PARAMETER 1(v0) 2 kB
increment_version OPTIMIZER_STATE 22(v0) 1024 kB
increment_version OPTIMIZER_STATE 22(v1) 1024 kB
increment_version OPTIMIZER_STATE 23(v0) 1024 kB
increment_version OPTIMIZER_STATE 23(v1) 1024 kB
create ??? 28(v0) 1024 kB
create ??? 29(v0) 1024 kB
destroy ??? 28(v0) 1024 kB
increment_version ??? 29(v0) 1024 kB
destroy ??? 27(v1) 2 kB
increment_version PARAMETER 2(v0) 1024 kB
destroy ??? 29(v1) 1024 kB
increment_version GRADIENT 16(v0) 128 kB
increment_version GRADIENT 17(v0) 2 kB
increment_version GRADIENT 13(v0) 1024 kB""")
if __name__ == "__main__":

View file

@ -2,6 +2,7 @@ import collections
import dataclasses
import enum
import itertools as it
import logging
from typing import (
Any,
cast,
@ -26,6 +27,7 @@ from torch._C._profiler import (
_TensorMetadata,
RecordScope,
)
from torch._utils import _element_size
from torch.profiler import _utils
TensorAndID = Tuple["TensorKey", int]
@ -305,6 +307,74 @@ class OpTree:
return self._sorted_nodes
class SizeMap:
    """Map each observed TensorKey (storage) to a byte size.

    If an allocation event is seen for a storage the size is unambiguous.
    Otherwise the size is bounded from the tensor metadata: for each tensor
    sharing the storage we take ``max(size * stride)`` elements times the
    element size, and keep the largest bound seen across all sightings.
    """

    def __init__(self, op_tree: OpTree) -> None:
        # Best-effort bounds derived from tensor metadata (sizes/strides).
        self._values: Dict[TensorKey, int] = {}

        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.TorchOp:
                for t in self._flat_tensor_inputs(node.typed[1]):
                    self._update_values(t)

            elif node.typed[0] == _EventType.PyCall:
                typed_fields = node.typed[1]
                # A PyCall frame records either a module or an optimizer,
                # never both.
                assert typed_fields.module is None or typed_fields.optimizer is None
                if typed_fields.module is not None:
                    for _, p, p_grad in typed_fields.module.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)

                if typed_fields.optimizer is not None:
                    for p, p_grad, state in typed_fields.optimizer.parameters:
                        self._update_values(p)
                        self._update_values(p_grad)
                        for _, t in state:
                            self._update_values(t)

        # Exact sizes from allocation events override metadata-derived bounds.
        allocations: Dict[TensorKey, int] = {}
        for node in op_tree.sorted_nodes:
            if node.typed[0] == _EventType.Allocation:
                alloc_fields = node.typed[1]
                key = TensorKey.from_allocation(alloc_fields)
                if key:
                    new_size = abs(alloc_fields.alloc_size)
                    prior_size = allocations.setdefault(key, new_size)

                    # It is possible to resize Storage in PyTorch, however we
                    # key on data pointer so most resizes will be treated as a
                    # change in storage. The one corner case that cannot be
                    # handled is `realloc` which successfully resizes the
                    # storage. At time of writing this is not done anywhere in
                    # the core PyTorch codebase.
                    if prior_size != new_size:
                        # NOTE: `logging.warn` is a deprecated alias; use
                        # `logging.warning` with lazy %-style args.
                        logging.warning(
                            "Mismatch between allocation and free: %s",
                            f"{prior_size} vs. {new_size}",
                        )

        self._values.update(allocations)

    def _update_values(self, t: Optional[_TensorMetadata]) -> None:
        """Fold a tensor sighting into the size bound for its storage."""
        key = TensorKey.from_tensor(t)
        if key is not None and t is not None and t.layout == torch.strided:
            # Scalars are represented as zero dim Tensors, hence the `or [1]`.
            n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))

            num_bytes = n * _element_size(t.dtype)
            assert num_bytes >= 0, f"{num_bytes}"
            self._values[key] = max(self._values.get(key, 0), num_bytes)

    @staticmethod
    def _flat_tensor_inputs(op: _ExtraFields_TorchOp) -> Iterator[_TensorMetadata]:
        """Yield every tensor input of a TorchOp, flattening tensor lists."""
        for i in op.inputs:
            if isinstance(i, _TensorMetadata):
                yield i
            elif isinstance(i, list):
                yield from i

    def __getitem__(self, key: TensorKey) -> int:
        return self._values[key]
@dataclasses.dataclass()
class DataFlowEdge:
input_version: Optional[int] = None
@ -564,6 +634,7 @@ class MemoryProfile:
def __init__(self, result: _ProfilerResult) -> None:
self._op_tree = OpTree(result)
self._data_flow_graph = DataFlowGraph(self._op_tree)
self._size_map = SizeMap(self._op_tree)
self._categories = CategoryDict()
self._set_gradients_and_temporaries()
@ -575,7 +646,7 @@ class MemoryProfile:
self._set_autograd_detail()
@property
def timeline(self) -> Tuple[Tuple[int, Action, TensorAndID], ...]:
def timeline(self) -> Tuple[Tuple[int, Action, TensorAndID, int], ...]:
t0 = min(event.start_time_ns for event in self._op_tree.dfs())
allocation_times: Dict[Tuple[TensorKey, bool], int] = {}
for event in self._op_tree.dfs():
@ -612,7 +683,10 @@ class MemoryProfile:
events.append((t, Action.DESTROY, (key, last_version[key])))
events.sort(key=lambda x: (x[0], x[1].value))
return tuple(events)
return tuple(
(time, action, (key, version), self._size_map[key])
for time, action, (key, version) in events
)
def _is_gradient(self, *args, **kwargs) -> bool:
return self._categories.get(*args, **kwargs) == Category.GRADIENT