Call profiler step via optimizer post hook (#90101)
This PR adds the `_profile_using_dynolog` function to `torch/__init__.py`. It allows registering the optimizer step post hook, which is required to collect iteration-based traces using dynolog.
Other related changes needed for the tests to pass:
1. Updated `optimizer.pyi`
1. Updated `overrides.py`
1. The test `test_kineto_profiler_multiple_steppers` in `test_profiler.py` has been broken down into two cases:
   - `test_kineto_profiler_multiple_steppers_with_override_True`: uses the override argument
   - `test_kineto_profiler_multiple_steppers_with_override_False`: uses the environment variable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90101
Approved by: https://github.com/albanD
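
For readers unfamiliar with the mechanism: below is a minimal sketch of how a global optimizer post hook can drive the Kineto step counter, using the public `register_optimizer_step_post_hook` and `KinetoStepTracker` APIs that this PR wires together (the model, optimizer, and loop are illustrative, not part of the change):

```python
import torch
import torch.nn as nn
from torch.autograd.profiler import KinetoStepTracker
from torch.optim.optimizer import register_optimizer_step_post_hook

def _post_hook(optimizer, args, kwargs):
    # Global post hooks run after every Optimizer.step() in the process,
    # so incrementing a named stepper here counts training iterations
    # without the training loop calling profiler.step() itself.
    KinetoStepTracker.increment_step("Optimizer")

handle = register_optimizer_step_post_hook(_post_hook)

net = nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
before = KinetoStepTracker.current_step()

net(torch.randn(3, 4)).sum().backward()
opt.step()  # fires _post_hook

assert KinetoStepTracker.current_step() == before + 1
handle.remove()  # registration returns a RemovableHandle for cleanup
```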
This commit is contained in:
parent 6783db13ef
commit f4b804eeaa
3 changed files with 88 additions and 1 deletion
test_profiler.py
```diff
@@ -8,10 +8,13 @@ import re
 import tempfile
 import textwrap
 import unittest
+from unittest.mock import patch
 from dataclasses import dataclass, field
 from typing import List, Optional
 
 import expecttest
+import subprocess
+import sys
 import torch
 import torch.nn as nn
 import torch.optim
```
```diff
@@ -1325,6 +1328,78 @@ class TestProfiler(TestCase):
         self.assertTrue(len(e.input_shapes) > 0)
         self.assertTrue(len(e.input_shapes[0]) > 0)
+
+    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
+    def test_kineto_profiler_with_environment_variable(self):
+        script = """
+import torch
+import torch.nn as nn
+from torch.profiler import supported_activities, profile
+from torch.autograd.profiler import KinetoStepTracker
+
+class SimpleNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(10, 5)
+        self.fc2 = nn.Linear(5, 2)
+
+    def forward(self, x):
+        return self.fc2(self.fc1(x))
+
+
+def payload(use_cuda=False):
+    x = torch.randn(10, 10)
+    if use_cuda:
+        x = x.cuda()
+    y = torch.randn(10, 10)
+    if use_cuda:
+        y = y.cuda()
+    z = torch.mm(x, y)
+    z = z + y
+    if use_cuda:
+        z = z.cpu()
+
+niters = 8
+use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
+net = SimpleNet()
+opt = torch.optim.SGD(net.parameters(), lr=0.01)
+opt.zero_grad()
+inputs = torch.rand(10)
+
+with profile(activities=supported_activities()):
+    payload(use_cuda=use_cuda)
+
+initial_step = KinetoStepTracker.current_step()
+
+def run_batch():
+    out = net(inputs)
+    loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
+    loss.backward()
+    opt.step()
+
+for _ in range(niters):
+    run_batch()
+
+with profile(
+    activities=supported_activities(),
+    schedule=torch.profiler.schedule(
+        wait=1,
+        warmup=1,
+        active=2),
+) as p:
+    for _ in range(niters):
+        run_batch()
+        p.step()
+assert KinetoStepTracker.current_step() == initial_step + 2 * niters
+"""
+        try:
+            subprocess.check_output(
+                [sys.executable, '-W', 'all', '-c', script],
+                cwd=os.path.dirname(os.path.realpath(__file__))
+            )
+        except subprocess.CalledProcessError as e:
+            if e.returncode != 0:
+                self.assertTrue(False, "Kineto is not working properly with the Dynolog environment variable")
 
 
 def find_node_with_name(nodes, name):
     for node in _utils.traverse_dfs(nodes):
```
optimizer.pyi
```diff
@@ -4,6 +4,9 @@ from torch.utils.hooks import RemovableHandle
 
 _params_t = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]
 
+def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle: ...
+
+def register_optimizer_step_post_hook(hook: Callable[..., None]) -> RemovableHandle: ...
 
 class Optimizer:
     defaults: Dict[str, Any]
```
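These stubs type the module-level hook registry in `torch.optim`. A hedged usage sketch (the counters are illustrative, not part of this PR):

```python
import torch
import torch.nn as nn
from torch.optim.optimizer import (
    register_optimizer_step_pre_hook,
    register_optimizer_step_post_hook,
)

calls = {"pre": 0, "post": 0}

# Global hooks receive (optimizer, args, kwargs) and wrap every .step()
# call on every optimizer instance in the process.
def count_pre(optimizer, args, kwargs):
    calls["pre"] += 1

def count_post(optimizer, args, kwargs):
    calls["post"] += 1

pre_handle = register_optimizer_step_pre_hook(count_pre)
post_handle = register_optimizer_step_post_hook(count_post)

model = nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(1, 2)).sum().backward()
opt.step()
assert calls == {"pre": 1, "post": 1}

# Both registration functions return a RemovableHandle, as the stubs state.
pre_handle.remove()
post_handle.remove()
```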
torch/profiler/__init__.py
```diff
@@ -7,9 +7,12 @@ examine their input shapes and stack traces, study device kernel activity and vi
 An earlier version of the API in :mod:`torch.autograd` module is considered legacy and will be deprecated.
 
 """
+import os
+
 from torch._C._autograd import _supported_activities, DeviceType, kineto_available
 from torch._C._profiler import _ExperimentalConfig, ProfilerActivity, RecordScope
-from torch.autograd.profiler import record_function
+from torch.autograd.profiler import record_function, KinetoStepTracker
+from torch.optim.optimizer import register_optimizer_step_post_hook
 
 from .profiler import (
     _KinetoProfile,
```
```diff
@@ -35,3 +38,9 @@ __all__ = [
 ]
 
 from . import itt
+
+def _optimizer_post_hook(optimizer, args, kwargs):
+    KinetoStepTracker.increment_step("Optimizer")
+
+if os.environ.get("KINETO_USE_DAEMON", None):
+    _ = register_optimizer_step_post_hook(_optimizer_post_hook)
```
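Since registration happens when `torch.profiler` is first imported, the environment variable must be in the process environment before Python starts (typically exported by dynolog's launcher), e.g. `KINETO_USE_DAEMON=1 python train.py`. A hedged sketch of what a training script then observes (`train.py`, the model, and the loop are illustrative):

```python
# Contents of a hypothetical train.py, launched with KINETO_USE_DAEMON=1.
# No profiler calls appear in the loop; the post hook registered by
# torch/profiler/__init__.py advances the Kineto step on every opt.step().
import torch
import torch.nn as nn
from torch.autograd.profiler import KinetoStepTracker

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters())
start = KinetoStepTracker.current_step()

for _ in range(3):
    model(torch.randn(2, 8)).sum().backward()
    opt.step()       # post hook fires here when the daemon flag is set
    opt.zero_grad()

print(KinetoStepTracker.current_step() - start)  # 3 with the flag, 0 without
```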