mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
There are multiple ways to identify that a Tensor is a gradient. (A subset of which also give additional context.) So to start off I've made a utility to handle that determination. Differential Revision: [D39920730](https://our.internmc.facebook.com/intern/diff/D39920730/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/86802 Approved by: https://github.com/chaekit
114 lines
4.2 KiB
Python
114 lines
4.2 KiB
Python
import dataclasses
|
|
from typing import Any, Iterator, Optional, Tuple
|
|
|
|
import torch
|
|
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class _Storage:
|
|
"""Bundle storage pointer and id.
|
|
|
|
All profiling logic should use `allocation_id`, however it is useful to
|
|
print storage pointers for debugging and unit tests sometimes look up
|
|
values using the storage data pointer of a live Tensor."""
|
|
|
|
ptr: int
|
|
allocation_id: int
|
|
|
|
def __repr__(self) -> str:
|
|
return f"{hex(self.ptr):>18} ({self.allocation_id})"
|
|
|
|
def __eq__(self, other: Any) -> bool:
|
|
return isinstance(other, _Storage) and self.allocation_id == other.allocation_id
|
|
|
|
def __hash__(self) -> int:
|
|
return hash(self.allocation_id)
|
|
|
|
|
|
@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=True)
|
|
class TensorKey:
|
|
"""Hashable identifier for a storage which has been asigned an ID.
|
|
|
|
A detailed description of Tensor IDs and why they are needed is given in
|
|
`torch/csrc/profiler/collection.h` when `TensorID` is declared. To
|
|
summarize, multiple Storage buffers can map to the same logical Tensor.
|
|
This dataclass is used to refer to a concrete in-memory StorageImpl of
|
|
a Tensor.
|
|
"""
|
|
|
|
id: int
|
|
storage: _Storage
|
|
device: torch.device
|
|
|
|
def __repr__(self) -> str:
|
|
return f"id={self.id}: {repr(self.storage):<24} ({self.device})"
|
|
|
|
@staticmethod
|
|
def _make(
|
|
tensor_id: Optional[int],
|
|
storage_ptr: Optional[int],
|
|
allocation_id: Optional[int],
|
|
device: torch.device,
|
|
) -> Optional["TensorKey"]:
|
|
if (
|
|
tensor_id is not None
|
|
and storage_ptr is not None
|
|
and allocation_id is not None
|
|
):
|
|
return TensorKey(tensor_id, _Storage(storage_ptr, allocation_id), device)
|
|
return None
|
|
|
|
@classmethod
|
|
def from_tensor(cls, t: Optional[_TensorMetadata]) -> Optional["TensorKey"]:
|
|
if t is not None:
|
|
return cls._make(t.id, t.storage_data_ptr, t.allocation_id, t.device)
|
|
return None
|
|
|
|
|
|
def extract_gradients(
    node: _ProfilerEvent,
) -> Iterator[Tuple[Optional[TensorKey], TensorKey]]:
    """Yield (parameter, gradient) TensorKey pairs detectable from `node`.

    The first element of each pair is the key of the parameter that owns the
    gradient — or None when ownership cannot be determined — and the second
    is the key of the gradient Tensor itself.
    """
    children = node.children

    # Case one: the Autograd engine applies gradient updates through
    # `AccumulateGrad` nodes. There are two sub-cases:
    #   1) A newly created gradient Tensor: nothing to accumulate, so autograd
    #      simply detaches it (`aten::detach`).
    #   2) A preexisting gradient Tensor: the fresh update is folded in with
    #      an in-place add (`aten::add_`; the underscore suffix denotes
    #      "in-place").
    # In both sub-cases the first input of the child op is the gradient.
    if node.typed[0] == _EventType.TorchOp:
        # TODO(robieta): Move away from load bearing names
        is_accumulate_grad = (
            node.typed[1].scope == RecordScope.BACKWARD_FUNCTION
            and node.name == "torch::autograd::AccumulateGrad"
        )
        if is_accumulate_grad and children:
            child = children[0]
            if (
                child.typed[0] == _EventType.TorchOp
                and child.name in ("aten::detach", "aten::add_")
                and child.typed[1].inputs
                and isinstance(child.typed[1].inputs[0], _TensorMetadata)
            ):
                grad_key = TensorKey.from_tensor(child.typed[1].inputs[0])
                if grad_key:
                    # Owning parameter is unknown in this code path.
                    yield None, grad_key

    # Case two: `torch.nn.Module` and `torch.optim.Optimizer` calls are
    # directly instrumented by the python tracer.
    # NOTE: the tracer caches these values; they can be used to build up
    # labels but do not imply that a Tensor was live at a particular time.
    elif node.typed[0] == _EventType.PyCall:
        fields = node.typed[1]
        # A single PyCall should carry at most one of module / optimizer.
        assert fields.module is None or fields.optimizer is None
        if fields.module is not None:
            for _, param, param_grad in fields.module.parameters:
                grad_key = TensorKey.from_tensor(param_grad)
                if grad_key is not None:
                    yield TensorKey.from_tensor(param), grad_key

        if fields.optimizer is not None:
            for param, param_grad, _ in fields.optimizer.parameters:
                grad_key = TensorKey.from_tensor(param_grad)
                if grad_key is not None:
                    yield TensorKey.from_tensor(param), grad_key