# mypy: allow-untyped-defs
import torchvision

import torch
from torch.distributed._tools import MemoryTracker

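# A small fail-fast guard, added here as a convenience (an assumption, not part
# of the original example): the calls below construct CUDA tensors directly,
# so exit early on CPU-only hosts instead of failing mid-run.
if not torch.cuda.is_available():
    raise SystemExit("This example requires a CUDA-capable GPU.")

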
def run_one_model(net: torch.nn.Module, input: torch.Tensor):
    # Move the model and its input to the GPU.
    net.cuda()
    input = input.cuda()

    # Create the memory tracker.
    mem_tracker = MemoryTracker()
    # Start monitoring before the training iteration starts.
    mem_tracker.start_monitor(net)

    # Run one training iteration.
    net.zero_grad(set_to_none=True)
    loss = net(input)
    # Torchvision segmentation models return a dict keyed by "out".
    if isinstance(loss, dict):
        loss = loss["out"]
    loss.sum().backward()
    net.zero_grad(set_to_none=True)

    # Stop monitoring after the training iteration ends.
    mem_tracker.stop()
    # Print the memory stats summary.
    mem_tracker.summary()
    # Plot the memory traces at the operator level.
    mem_tracker.show_traces()
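    # Optionally persist the collected stats for offline comparison. A sketch,
    # assuming this build of MemoryTracker provides save_stats()/load() (check
    # torch/distributed/_tools/memory_tracker.py for your version):
    # mem_tracker.save_stats("memory_stats.pickle")
    # mem_tracker.load("memory_stats.pickle")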


run_one_model(torchvision.models.resnet34(), torch.rand(32, 3, 224, 224, device="cuda"))
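
# The isinstance(loss, dict) branch in run_one_model handles torchvision
# segmentation models, which return a dict keyed by "out". A minimal sketch of
# such a run (assumption: fcn_resnet50 is built with random weights here, the
# torchvision default when no weights argument is passed):
run_one_model(
    torchvision.models.segmentation.fcn_resnet50(),
    torch.rand(8, 3, 224, 224, device="cuda"),
)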