mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132352 Approved by: https://github.com/ezyang ghstack dependencies: #132335, #132351
156 lines
5.5 KiB
Python
156 lines
5.5 KiB
Python
# Owner(s): ["oncall: profiler"]
|
|
|
|
# If tqdm is not shut down properly, it leaves its monitor thread alive.
# This causes an issue in the multithreading test because we check all events
# in that test with their tids. The events that correspond to these lingering
# threads all have a TID of (uint64_t)(-1), which is invalid.
# The workaround is turning off the monitor thread when tqdm is loaded.
# Since these are unit tests, it is safe to turn off the monitor thread.
try:
    import tqdm
except ImportError:
    # tqdm is optional here; when it is not installed there is nothing to
    # turn off.
    pass
else:
    tqdm.tqdm.monitor_interval = 0
|
|
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
import torch.optim
|
|
import torch.utils.data
|
|
import torch.utils.data.datapipes as dp
|
|
from torch.autograd import (
|
|
_record_function_with_args_enter,
|
|
_record_function_with_args_exit,
|
|
)
|
|
from torch.autograd.profiler import profile as _profile
|
|
from torch.profiler import kineto_available, record_function
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
|
|
|
|
# Type alias for a dict with string keys and arbitrary values (presumably a
# decoded JSON object, going by the name). It is not referenced anywhere else
# in the visible part of this module.
Json = Dict[str, Any]
|
|
|
|
|
|
class TestRecordFunction(TestCase):
|
|
def _record_function_with_param(self):
|
|
u = torch.randn(3, 4, 5, requires_grad=True)
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
with record_function("## TEST 1 ##", "1, 2, 3"):
|
|
rf_handle = _record_function_with_args_enter(
|
|
"## TEST 2 ##", 1, False, 2.5, [u, u], "hello", u
|
|
)
|
|
_record_function_with_args_exit(rf_handle)
|
|
with record_function("## TEST 3 ##"):
|
|
rf_handle = _record_function_with_args_enter("## TEST 4 ##")
|
|
_record_function_with_args_exit(rf_handle)
|
|
return prof
|
|
|
|
def test_record_function(self):
|
|
prof_result = self._record_function_with_param()
|
|
found_test_1 = False
|
|
found_test_2 = False
|
|
found_test_3 = False
|
|
found_test_4 = False
|
|
for e in prof_result.function_events:
|
|
if "## TEST 1 ##" == e.name:
|
|
found_test_1 = True
|
|
self.assertTrue(e.input_shapes == [[]])
|
|
elif "## TEST 2 ##" == e.name:
|
|
found_test_2 = True
|
|
self.assertTrue(e.input_shapes == [[], [], [], [], [], [3, 4, 5]])
|
|
elif "## TEST 3 ##" == e.name:
|
|
found_test_3 = True
|
|
self.assertTrue(e.input_shapes == [])
|
|
elif "## TEST 4 ##" == e.name:
|
|
found_test_4 = True
|
|
self.assertTrue(e.input_shapes == [])
|
|
self.assertTrue(found_test_1)
|
|
self.assertTrue(found_test_2)
|
|
self.assertTrue(found_test_3)
|
|
self.assertTrue(found_test_4)
|
|
|
|
def test_datapipe_with_record_function(self):
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
input_dp1 = dp.iter.IterableWrapper(range(4))
|
|
input_dp2 = dp.iter.IterableWrapper(range(4, 8))
|
|
input_dp3 = dp.iter.IterableWrapper(range(8, 12))
|
|
output_dp = input_dp1.mux(input_dp2, input_dp3)
|
|
output = list(output_dp)
|
|
|
|
has_iter = False
|
|
has_mux = False
|
|
for e in prof.function_events:
|
|
if has_iter and has_mux:
|
|
break
|
|
|
|
if not has_iter and "IterableWrapper" in e.name:
|
|
has_iter = True
|
|
if not has_mux and "Multiplexer" in e.name:
|
|
has_mux = True
|
|
self.assertTrue(has_iter)
|
|
self.assertTrue(has_mux)
|
|
|
|
def test_datapipe_delegation_with_profiler(self):
|
|
class IDPIterator(torch.utils.data.IterDataPipe):
|
|
def __init__(self) -> None:
|
|
self.data = list(range(10))
|
|
self._idx = 0
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self._idx >= 10:
|
|
self._idx = 0
|
|
raise StopIteration
|
|
self._idx += 1
|
|
return self.data[self._idx - 1]
|
|
|
|
def get_value(self, idx):
|
|
return self.data[idx]
|
|
|
|
dp1 = IDPIterator() # The object itself is an iterator
|
|
self.assertEqual(5, dp1.get_value(5))
|
|
it_dp1 = iter(dp1) # This creates the 1st iterator
|
|
self.assertEqual(5, it_dp1.get_value(5)) # type: ignore[attr-defined]
|
|
self.assertEqual(list(range(10)), list(it_dp1))
|
|
|
|
class IDPDelegator(torch.utils.data.IterDataPipe):
|
|
def __init__(self, datapipe):
|
|
self.datapipe = datapipe
|
|
|
|
def __iter__(self):
|
|
return iter(self.datapipe)
|
|
|
|
dp2 = IDPDelegator(dp1)
|
|
it_dp2 = iter(dp2)
|
|
self.assertEqual(5, it_dp2.get_value(5))
|
|
self.assertEqual(list(range(10)), list(it_dp2))
|
|
|
|
def test_datapipe_with_record_function_fork(self):
|
|
with _profile(
|
|
with_stack=True, use_kineto=kineto_available(), record_shapes=True
|
|
) as prof:
|
|
input_dp = dp.iter.IterableWrapper(range(10))
|
|
dp1, dp2, dp3 = input_dp.fork(num_instances=3)
|
|
output1 = list(dp1)
|
|
has_iter = False
|
|
has_child = False
|
|
for e in prof.function_events:
|
|
if has_iter and has_child:
|
|
break
|
|
|
|
if not has_iter and "IterableWrapper" in e.name:
|
|
has_iter = True
|
|
if not has_child and "_ChildDataPipe" in e.name:
|
|
has_child = True
|
|
self.assertTrue(has_iter)
|
|
self.assertTrue(has_child)
|
|
|
|
|
|
# Script entry point: when this module is executed directly, hand control to
# the run_tests() helper imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()
|