pytorch/test/profiler/test_cpp_thread.py
Aaron Orenstein 8c356ce3da Fix lint errors in fbcode (#135614)
Summary: Fixed a number of fbcode imports that happened to work but confused autodeps. After this change, autodeps still suggests "improvements" to TARGETS (which break our builds), but at least it can find all the imports.

Test Plan:
```
fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS
```
Before:
```
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$
ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib.  Some things to try:
```

Differential Revision: D62049222

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614
Approved by: https://github.com/oulgen, https://github.com/laithsakka
2024-09-13 02:04:34 +00:00

219 lines
7.2 KiB
Python

# Owner(s): ["oncall: profiler"]
import os
import shutil
import subprocess
from unittest import skipIf
import torch
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
def remove_build_path():
    """Delete the default ``torch.utils.cpp_extension`` build root if it exists.

    Called after the tests so JIT-compiled extension artifacts are not left
    behind.
    """
    build_root = torch.utils.cpp_extension.get_default_build_root()
    if not os.path.exists(build_root):
        return
    if not IS_WINDOWS:
        shutil.rmtree(build_root)
        return
    # On Windows shutil.rmtree fails with "[WinError 5] Access is denied",
    # so shell out to `rm -rf` as a workaround.
    subprocess.run(["rm", "-rf", build_root], stdout=subprocess.PIPE)
def is_fbcode():
    """Return True when torch looks like an fbcode (internal) build.

    The check is that ``torch.version`` lacks a ``git_version`` attribute.
    Its absence is taken as the fbcode signal.
    """
    has_git_version = hasattr(torch.version, "git_version")
    return not has_git_version
# Bring the C++ helper extension into scope as ``cpp``. Under fbcode it is a
# prebuilt library; otherwise it is JIT-compiled from test_cpp_thread.cpp,
# which sits next to this file.
if is_fbcode():
    import caffe2.test.profiler_test_cpp_thread_lib as cpp  # @manual=//caffe2/test:profiler_test_cpp_thread_lib
else:
    # cpp extensions use relative paths. Those paths are relative to
    # this file, so we'll change the working directory temporarily
    old_working_dir = os.getcwd()
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    try:
        cpp = torch.utils.cpp_extension.load(
            name="profiler_test_cpp_thread_lib",
            sources=[
                "test_cpp_thread.cpp",
            ],
            verbose=True,
        )
    finally:
        # Restore the original working directory even if the build fails,
        # so an import-time compile error does not leave the process
        # chdir'ed into this file's directory.
        os.chdir(old_working_dir)
# Profiler shared by the test case, which creates it in start_profiler(), and
# the iteration callback, which starts, steps and stops it.
KinetoProfiler = None

# Iterations passed to cpp.start_threads(), and the length of the profiler
# schedule's "active" window.
IterationCount, ActivateIteration = 5, 2
def blueprint(text):
    """Print ``text`` to stdout wrapped in ANSI blue-foreground escape codes."""
    colored = f"\33[34m{text}\33[0m"
    print(colored)
# onIterationStart() will be called by the C++ training engine built from
# test_cpp_thread.cpp.
class PythonProfilerEventHandler(cpp.ProfilerEventHandler):
    """Python-side callbacks for the C++ training-loop emulator.

    The handler is registered with ``cpp.ProfilerEventHandler.Register``.
    It drives the module-level ``KinetoProfiler`` from the C++ iteration
    loop.
    """

    def onIterationStart(self, iteration: int) -> None:
        """Start, step, or stop the global Kineto profiler for ``iteration``.

        Iteration 0 starts the profiler. The last iteration
        (``IterationCount - 1``) stops it. Every iteration in between
        advances the profiler schedule with ``step()``.
        """
        # It is important to start the profiler on the same thread that
        # step() is called on, and onIterationStart() is always called on
        # the same thread.
        if iteration == 0:
            # This also means step() first runs on iteration 1, not 0.
            KinetoProfiler.start()
            blueprint("starting kineto profiler")
        elif iteration == IterationCount - 1:
            KinetoProfiler.stop()
            blueprint("stopping kineto profiler")
        else:
            blueprint("stepping kineto profiler")
            KinetoProfiler.step()

    def emulateTraining(self, iteration: int, thread_id: int) -> None:
        """Run a tiny CUDA workload inside a ``user_function`` record_function range.

        ``iteration`` and ``thread_id`` are unused here. Presumably the C++
        side supplies them for each call (confirm in test_cpp_thread.cpp).
        """
        device = torch.device("cuda")
        with torch.autograd.profiler.record_function("user_function"):
            a = torch.ones(1, device=device)
            b = torch.ones(1, device=device)
            torch.add(a, b).cpu()
            torch.cuda.synchronize()
class CppThreadTest(TestCase):
    """Check that Kineto captures events recorded on C++-spawned threads.

    Each test does three things:

    1. Build a fresh profiler with ``start_profiler``.
    2. Run ``cpp.start_threads(thread_count, IterationCount, flag)``.
    3. Verify the events in the trace delivered to ``set_trace``.
    """

    ThreadCount = 20  # set to 2 for debugging
    # Handler instance registered with the C++ library in setUpClass.
    EventHandler = None
    # Latest trace delivered through on_trace_ready (see set_trace). It is
    # reset to None by start_profiler so a stale trace is never checked.
    TraceObject = None

    @classmethod
    def setUpClass(cls) -> None:
        # Zero-argument super() also runs TestCase.setUpClass. The previous
        # super(TestCase, cls) form started the MRO lookup *after* TestCase
        # and so skipped it.
        super().setUpClass()
        CppThreadTest.EventHandler = PythonProfilerEventHandler()
        cpp.ProfilerEventHandler.Register(CppThreadTest.EventHandler)

    @classmethod
    def tearDownClass(cls):
        # Only the non-fbcode path JIT-builds the extension via
        # torch.utils.cpp_extension.load, so only it has a build root to remove.
        if not is_fbcode():
            remove_build_path()
        super().tearDownClass()

    def setUp(self) -> None:
        # Run the shared TestCase per-test setup before our own.
        super().setUp()
        if not torch.cuda.is_available():
            self.skipTest("Test machine does not have cuda")
        # this clears off events from initialization
        self.start_profiler(False)
        cpp.start_threads(1, IterationCount, False)

    def start_profiler(self, profile_memory):
        """Create a new global Kineto profiler without starting it.

        The C++ iteration callback (``onIterationStart``) starts, steps and
        stops it. The schedule waits 1 step, warms up 1 step, then records
        ``ActivateIteration`` steps, once.

        Args:
            profile_memory: forwarded to ``torch.profiler.profile``.
        """
        global KinetoProfiler
        # Forget any earlier trace (e.g. the warm-up run in setUp) so
        # check_trace fails loudly if this run never produces one.
        CppThreadTest.TraceObject = None
        KinetoProfiler = torch.profiler.profile(
            schedule=torch.profiler.schedule(
                wait=1, warmup=1, active=ActivateIteration, repeat=1
            ),
            on_trace_ready=self.set_trace,
            with_stack=True,
            profile_memory=profile_memory,
            record_shapes=True,
        )

    def set_trace(self, trace_obj) -> None:
        """``on_trace_ready`` callback that stashes the finished trace on the class."""
        CppThreadTest.TraceObject = trace_obj

    def assert_text(self, condition, text, msg):
        """Print ``text`` in green (pass) or red (fail), then assert ``condition``."""
        if condition:
            print(f"\33[32m{text}\33[0m")
        else:
            print(f"\33[31m{text}\33[0m")
        self.assertTrue(condition, msg)

    def check_trace(self, expected, mem=False) -> None:
        """Verify event counts in the most recent trace.

        Args:
            expected: maps an event name to ``[count, device]``. ``device`` is
                a ``DeviceType`` member name such as ``"CPU"`` or ``"CUDA"``.
            mem: if True, count only events whose string form reports a
                ``cuda_memory_usage`` other than ``0``, and require at least
                ``count * (ActivateIteration - 1)`` of them.

        Without ``mem``, a ``count`` of 1 is the baseline: exactly
        ``ActivateIteration`` events must match. Any other ``count`` needs
        at least ``count * (ActivateIteration - 1)`` matches.
        """
        blueprint("verifying trace")
        trace = CppThreadTest.TraceObject
        self.assertIsNotNone(trace, "profiler did not deliver a trace")
        event_list = trace.events()
        for key, values in expected.items():
            count = values[0]
            min_count = count * (ActivateIteration - 1)
            device = values[1]
            # Consumed within this same loop iteration, so the lambda's
            # late binding of key/device is safe.
            filtered = filter(
                lambda ev: ev.name == key
                and str(ev.device_type) == f"DeviceType.{device}",
                event_list,
            )
            if mem:
                actual = 0
                for ev in filtered:
                    sev = str(ev)
                    has_cuda_memory_usage = (
                        sev.find("cuda_memory_usage=0 ") < 0
                        and sev.find("cuda_memory_usage=") > 0
                    )
                    if has_cuda_memory_usage:
                        actual += 1
                self.assert_text(
                    actual >= min_count,
                    f"{key}: {actual} >= {min_count}",
                    "not enough event with cuda_memory_usage set",
                )
            else:
                actual = len(list(filtered))
                if count == 1:  # test_without
                    count *= ActivateIteration
                    self.assert_text(
                        actual == count,
                        f"{key}: {actual} == {count}",
                        "baseline event count incorrect",
                    )
                else:
                    self.assert_text(
                        actual >= min_count,
                        f"{key}: {actual} >= {min_count}",
                        "not enough event recorded",
                    )

    @skipIf(
        IS_WINDOWS,
        "Failing on windows cuda, see https://github.com/pytorch/pytorch/pull/130037 for slightly more context",
    )
    def test_with_enable_profiler_in_child_thread(self) -> None:
        self.start_profiler(False)
        cpp.start_threads(self.ThreadCount, IterationCount, True)
        self.check_trace(
            {
                "aten::add": [self.ThreadCount, "CPU"],
                "user_function": [self.ThreadCount, "CUDA"],
            }
        )

    @skipIf(
        IS_WINDOWS,
        "Failing on windows cuda, see https://github.com/pytorch/pytorch/pull/130037 for slightly more context",
    )
    def test_without_enable_profiler_in_child_thread(self) -> None:
        self.start_profiler(False)
        cpp.start_threads(self.ThreadCount, IterationCount, False)
        self.check_trace(
            {
                "aten::add": [1, "CPU"],
                "user_function": [1, "CUDA"],
            }
        )

    @skipIf(
        IS_WINDOWS,
        "Failing on windows cuda, see https://github.com/pytorch/pytorch/pull/130037 for slightly more context",
    )
    def test_profile_memory(self) -> None:
        self.start_profiler(True)
        cpp.start_threads(self.ThreadCount, IterationCount, True)
        self.check_trace(
            {
                "aten::add": [self.ThreadCount, "CPU"],
            },
            mem=True,
        )
if __name__ == "__main__":
    # Executed directly (not imported): hand off to torch's shared test runner.
    run_tests()