mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add compile time profiler for non fbcode targets (#126904)
This is a tool that allow profiling compile time using strobelight profiler, its a meta only tool. but works on non-fbcode targets. A follow up diff will unify this with caffe2/fb/strobelight/compile_time_profiler.py. example test: ``` run python tools/strobelight/examples/compile_time_profile_example.py ``` ``` python torch/utils/_strobelight/examples/compile_time_profile_example.py strobelight_compile_time_profiler, line 61, 2024-05-23 10:49:28,101, INFO: compile time strobelight profiling enabled strobelight_compile_time_profiler, line 93, 2024-05-23 10:49:28,102, INFO: Unique sample tag for this run is: 2024-05-23-10:49:282334638devvm4561.ash0.facebook.com strobelight_compile_time_profiler, line 94, 2024-05-23 10:49:28,102, INFO: You can use the following link to access the strobelight profile at the end of the run: https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experimental%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22now%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22namespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Angeles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22order%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[[%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%22%3A[%22[%5C%222024-05-23-10:49:282334638devvm4561.ash0.facebook.com%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&view=GraphProfilerView&&normalized=1712358002&pool=uber strobelight_function_profiler, line 241, 2024-05-23 10:49:34,943, INFO: strobelight run id is: 3507039740348330 strobelight_function_profiler, line 243, 2024-05-23 10:50:00,907, INFO: strobelight profiling running strobelight_function_profiler, line 224, 2024-05-23 10:50:02,741, INFO: strobelight profiling stopped strobelight_function_profiler, line 215, 2024-05-23 10:50:06,173, INFO: Total samples: 7 strobelight_function_profiler, line 215, 2024-05-23 10:50:06,173, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/75cxdro3 strobelight_function_profiler, line 215, 2024-05-23 10:50:06,173, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/qsgydsee strobelight_compile_time_profiler, line 120, 2024-05-23 10:50:06,174, INFO: 1 strobelight success runs out of 1 non-recursive compilation events. strobelight_function_profiler, line 241, 2024-05-23 10:50:08,137, INFO: strobelight run id is: 8721740011604497 strobelight_function_profiler, line 243, 2024-05-23 10:50:34,801, INFO: strobelight profiling running strobelight_function_profiler, line 224, 2024-05-23 10:50:36,803, INFO: strobelight profiling stopped strobelight_function_profiler, line 215, 2024-05-23 10:50:41,289, INFO: Total samples: 3 strobelight_function_profiler, line 215, 2024-05-23 10:50:41,289, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/qmi2ucwp strobelight_function_profiler, line 215, 2024-05-23 10:50:41,289, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/7fjkhs9i strobelight_compile_time_profiler, line 120, 2024-05-23 10:50:41,289, INFO: 2 strobelight success runs out of 2 non-recursive compilation events. strobelight_function_profiler, line 241, 2024-05-23 10:50:43,597, INFO: strobelight run id is: 1932476082259558 strobelight_function_profiler, line 243, 2024-05-23 10:51:09,791, INFO: strobelight profiling running strobelight_function_profiler, line 224, 2024-05-23 10:51:11,883, INFO: strobelight profiling stopped strobelight_function_profiler, line 215, 2024-05-23 10:51:16,218, INFO: Total samples: 3 strobelight_function_profiler, line 215, 2024-05-23 10:51:16,218, INFO: GraphProfiler (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/vy1ujxec strobelight_function_profiler, line 215, 2024-05-23 10:51:16,218, INFO: Icicle view (python stack): https://fburl.com/scuba/pyperf_experimental/on_demand/2xgadviv strobelight_compile_time_profiler, line 120, 2024-05-23 10:51:16,219, INFO: 3 strobelight success runs out of 3 non-recursive compilation events. ``` or pass TORCH_COMPILE_STROBELIGHT=TRUE for any torch compile python program. ex running on XLNetLMHeadModel. ``` TORCH_COMPILE_STROBELIGHT=TRUE TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 time python benchmarks/dynamo/huggingface.py --ci --accuracy --timing --explain --inductor --device cuda --training --amp --only XLNetLMHeadModel ``` result: Pull Request resolved: https://github.com/pytorch/pytorch/pull/126904 Approved by: https://github.com/aorenste ghstack dependencies: #126444
This commit is contained in:
parent
2b72e2a596
commit
cdf2133186
6 changed files with 538 additions and 3 deletions
0
torch/_strobelight/__init__.py
Normal file
0
torch/_strobelight/__init__.py
Normal file
311
torch/_strobelight/cli_function_profiler.py
Normal file
311
torch/_strobelight/cli_function_profiler.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
# mypy: disallow-untyped-defs
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from threading import Lock
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
|
||||
logger = logging.getLogger("strobelight_function_profiler")
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
"%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
class StrobelightCLIProfilerError(Exception):
|
||||
"""
|
||||
Raised when an error happens during strobelight profiling
|
||||
"""
|
||||
|
||||
|
||||
def _pid_namespace_link(pid: Optional[int] = None) -> str:
|
||||
"""Returns the link to the process's namespace, example: pid:[4026531836]"""
|
||||
PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
|
||||
pid = pid or os.getpid()
|
||||
return os.readlink(PID_NAMESPACE_PATH.format(pid))
|
||||
|
||||
|
||||
def _pid_namespace(pid: Optional[int] = None) -> int:
|
||||
"""Returns the process's namespace id"""
|
||||
pid = pid or os.getpid()
|
||||
link = _pid_namespace_link(pid)
|
||||
return int(link[link.find("[") + 1 : -1])
|
||||
|
||||
|
||||
def _command_to_string(command: Sequence[str]) -> str:
|
||||
return " ".join(command)
|
||||
|
||||
|
||||
class StrobelightCLIFunctionProfiler:
|
||||
"""
|
||||
Note: this is a Meta only tool.
|
||||
|
||||
StrobelightCLIFunctionProfiler can be used to profile a python function and
|
||||
generate a strobelight link with the results. It works on meta servers but
|
||||
does not requries an fbcode target.
|
||||
When stop_at_error is false(default), error during profiling does not prevent
|
||||
the work function from running.
|
||||
|
||||
Check function_profiler_example.py for an example.
|
||||
"""
|
||||
|
||||
# This lock is used to make sure only one thread is running the profiler at any point.
|
||||
_lock = Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
stop_at_error: bool = False,
|
||||
max_profile_duration_sec: int = 60 * 10,
|
||||
sample_each: float = 1e7, # sample each sample_each cycles.
|
||||
run_user_name: str = "pytorch-strobelight-ondemand",
|
||||
timeout_wait_for_running_sec: int = 60,
|
||||
timeout_wait_for_finished_sec: int = 60,
|
||||
recorded_env_variables: Optional[List[str]] = None,
|
||||
sample_tags: Optional[List[str]] = None,
|
||||
stack_max_len: int = 127,
|
||||
async_stack_max_len: int = 127,
|
||||
):
|
||||
self.stop_at_error = stop_at_error
|
||||
self.max_profile_duration_sec = max_profile_duration_sec
|
||||
self.sample_each = sample_each
|
||||
self.run_user_name = run_user_name
|
||||
self.timeout_wait_for_running_sec = timeout_wait_for_running_sec
|
||||
self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
|
||||
# Results of the most recent run.
|
||||
# Tracks the strobelight run id of the most recent run
|
||||
self.current_run_id: Optional[int] = None
|
||||
self.profile_result: Optional[List[str]] = None
|
||||
self.sample_tags = sample_tags
|
||||
|
||||
def _run_async(self) -> None:
|
||||
processId = os.getpid()
|
||||
namespace = _pid_namespace(processId)
|
||||
command = [
|
||||
"strobeclient",
|
||||
"run",
|
||||
"--profiler",
|
||||
"pyperf",
|
||||
"--event",
|
||||
"cycles",
|
||||
"--async",
|
||||
"--sample-interval",
|
||||
f"{int(self.sample_each)}",
|
||||
"--duration-ms",
|
||||
f"{int(self.max_profile_duration_sec * 1000)}",
|
||||
"--pid",
|
||||
f"{namespace}:{processId}",
|
||||
]
|
||||
|
||||
if self.sample_tags:
|
||||
command.append("--sample-tags")
|
||||
command.append(",".join(self.sample_tags))
|
||||
|
||||
logger.debug("running command: %s", _command_to_string(command))
|
||||
result = subprocess.run(command, capture_output=True)
|
||||
output = result.stderr.decode("utf-8")
|
||||
logger.debug("output:\n{%s}", output)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to start strobelight profiling, error in run_async:{output}"
|
||||
)
|
||||
|
||||
if match := re.search(r"INFO Run Id: (-?\d+)", output):
|
||||
self.current_run_id = int(match.group(1))
|
||||
return
|
||||
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to start strobelight profiling, unexpected result {output}"
|
||||
)
|
||||
|
||||
def _wait_for_running(self, counter: int = 0) -> None:
|
||||
if counter > 20:
|
||||
raise StrobelightCLIProfilerError(
|
||||
"wait_for_running called more than 20 times"
|
||||
)
|
||||
|
||||
command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"]
|
||||
logger.debug("running command: %s", _command_to_string(command))
|
||||
result = subprocess.run(command, capture_output=True)
|
||||
output = result.stderr.decode("utf-8")
|
||||
logger.debug("output:\n{%s}", output)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to start strobelight profiling, error in wait_for_running:{output}"
|
||||
)
|
||||
|
||||
if match := re.search("Profile run status: (.*)", output):
|
||||
current_status = match.group(1)
|
||||
if current_status == "RUNNING":
|
||||
return
|
||||
elif current_status == "PREPARING":
|
||||
time.sleep(10)
|
||||
self._wait_for_running(counter + 1)
|
||||
return
|
||||
else:
|
||||
raise StrobelightCLIProfilerError(f"unexpected {current_status} phase")
|
||||
|
||||
raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
|
||||
|
||||
def _stop_run(self) -> None:
|
||||
command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)]
|
||||
logger.debug("running command: %s", _command_to_string(command))
|
||||
result = subprocess.run(command, capture_output=True)
|
||||
output = result.stderr.decode("utf-8")
|
||||
logger.debug("output:\n{%s}", output)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to stop strobelight profiling, return code is not 0 :{output}"
|
||||
)
|
||||
|
||||
if match := re.search("INFO ::1:(.*)", output):
|
||||
current_status = match.group(1)
|
||||
if current_status.__contains__("Success!"):
|
||||
return
|
||||
else:
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to stop strobelight profiling, got {current_status} result"
|
||||
)
|
||||
|
||||
raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")
|
||||
|
||||
def _get_results(self) -> None:
|
||||
command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)]
|
||||
logger.debug("running command: %s", _command_to_string(command))
|
||||
result = subprocess.run(command, capture_output=True)
|
||||
output = result.stderr.decode("utf-8")
|
||||
logger.debug("output:\n{%s}", output)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to extract profiling results, return code is not 0 : {output}"
|
||||
)
|
||||
|
||||
if match := re.search("INFO ::1:(.*)", output):
|
||||
current_status = match.group(1)
|
||||
if current_status.__contains__("Profile run status: PROCESSING"):
|
||||
time.sleep(10)
|
||||
self._get_results()
|
||||
return
|
||||
elif not current_status.__contains__("Profile run finished with SUCCESS"):
|
||||
raise StrobelightCLIProfilerError(
|
||||
f"failed to extract profiling results, unexpected response {output}"
|
||||
)
|
||||
|
||||
self.profile_result = []
|
||||
for item in re.findall(
|
||||
r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))",
|
||||
output,
|
||||
):
|
||||
self.profile_result += item[0]
|
||||
logger.info(item[0])
|
||||
|
||||
def _stop_strobelight_no_throw(
|
||||
self,
|
||||
collect_results: bool,
|
||||
) -> None:
|
||||
try:
|
||||
# call stop run
|
||||
self._stop_run()
|
||||
logger.info("strobelight profiling stopped")
|
||||
|
||||
logger.debug("collection stopped")
|
||||
|
||||
if not collect_results:
|
||||
return
|
||||
|
||||
self._get_results()
|
||||
except Exception as error:
|
||||
logger.warning("error during stop_strobelight", exc_info=True)
|
||||
|
||||
# Return true if strobelight started and is running. Never throw.
|
||||
def _start_strobelight(self) -> bool:
|
||||
strobelight_started = False
|
||||
try:
|
||||
self._run_async()
|
||||
strobelight_started = True
|
||||
logger.info("strobelight run id is: %s", self.current_run_id)
|
||||
self._wait_for_running()
|
||||
logger.info("strobelight profiling running")
|
||||
return True
|
||||
|
||||
except Exception as error:
|
||||
logger.warning("error during start_strobelight:", exc_info=True)
|
||||
if strobelight_started:
|
||||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
return False
|
||||
|
||||
def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
self.current_run_id = None
|
||||
self.profile_result = None
|
||||
|
||||
if locked := StrobelightCLIFunctionProfiler._lock.acquire(False):
|
||||
if not locked:
|
||||
if self.stop_at_error:
|
||||
raise StrobelightCLIProfilerError("concurrent runs not supported")
|
||||
|
||||
logger.warning("concurrent runs not supported")
|
||||
return work_function(*args, **kwargs)
|
||||
|
||||
started = self._start_strobelight()
|
||||
if not started:
|
||||
if self.stop_at_error:
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
raise StrobelightCLIProfilerError(
|
||||
"failed to start strobelight profiling"
|
||||
)
|
||||
result = work_function(*args, **kwargs)
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
return result
|
||||
|
||||
try:
|
||||
logger.debug("collection started")
|
||||
start = timer()
|
||||
result = work_function(*args, **kwargs)
|
||||
end = timer()
|
||||
total_time = end - start # Time in seconds, e.g. 5.38091952400282
|
||||
logger.info("work function took %s seconds", total_time)
|
||||
self._stop_strobelight_no_throw(collect_results=True)
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
return result
|
||||
except Exception as error:
|
||||
logger.warning("work function throw exception", exc_info=True)
|
||||
self._stop_strobelight_no_throw(collect_results=False)
|
||||
StrobelightCLIFunctionProfiler._lock.release()
|
||||
raise error
|
||||
|
||||
|
||||
# A function decorator that wraps profile, if no profiler is provided one with
|
||||
# default args is created. A function can be annotated as:
|
||||
# @strobelight()
|
||||
# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..))
|
||||
# @strobelight(stop_at_error=True,...)
|
||||
def strobelight(
|
||||
profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
if not profiler:
|
||||
profiler = StrobelightCLIFunctionProfiler(**kwargs)
|
||||
|
||||
def strobelight_inner(work_function: Any) -> Any:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
return profiler.profile(work_function, *args, **kwargs)
|
||||
|
||||
return wrapper_function
|
||||
|
||||
return strobelight_inner
|
||||
156
torch/_strobelight/compile_time_profiler.py
Normal file
156
torch/_strobelight/compile_time_profiler.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
# mypy: disallow-untyped-defs
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from datetime import datetime
|
||||
from socket import gethostname
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler
|
||||
|
||||
logger = logging.getLogger("strobelight_compile_time_profiler")
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
"%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(console_handler)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
class StrobelightCompileTimeProfiler:
|
||||
success_profile_count: int = 0
|
||||
failed_profile_count: int = 0
|
||||
ignored_profile_runs: int = 0
|
||||
inside_profile_compile_time: bool = False
|
||||
enabled: bool = False
|
||||
# A unique identifier that is used as the run_user_name in the strobelight profile to
|
||||
# associate all compile time profiles together.
|
||||
identifier: Optional[str] = None
|
||||
|
||||
current_phase: Optional[str] = None
|
||||
|
||||
profiler: Optional[Any] = None
|
||||
|
||||
max_stack_length: int = int(
|
||||
os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 127)
|
||||
)
|
||||
max_profile_time: int = int(
|
||||
os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30)
|
||||
)
|
||||
# Collect sample each x cycles.
|
||||
sample_each: int = int(
|
||||
float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None:
|
||||
if cls.enabled:
|
||||
logger.info("compile time strobelight profiling already enabled")
|
||||
return
|
||||
|
||||
logger.info("compile time strobelight profiling enabled")
|
||||
|
||||
if profiler_class is StrobelightCLIFunctionProfiler:
|
||||
import shutil
|
||||
|
||||
if not shutil.which("strobeclient"):
|
||||
logger.info(
|
||||
"strobeclient not found, cant enable compile time strobelight profiling, seems"
|
||||
"like you are not on a FB machine."
|
||||
)
|
||||
return
|
||||
|
||||
cls.enabled = True
|
||||
cls._cls_init()
|
||||
# profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler.
|
||||
# we have pass different functionProfilerClass for meta-internal fbcode targets.
|
||||
cls.profiler = profiler_class(
|
||||
sample_each=cls.sample_each,
|
||||
max_profile_duration_sec=cls.max_profile_time,
|
||||
stack_max_len=cls.max_stack_length,
|
||||
async_stack_max_len=cls.max_stack_length,
|
||||
run_user_name="pt2-profiler/"
|
||||
+ os.environ.get("USER", os.environ.get("USERNAME", "")),
|
||||
sample_tags={cls.identifier},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _cls_init(cls) -> None:
|
||||
cls.identifier = "{date}{pid}{hostname}".format(
|
||||
date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"),
|
||||
pid=os.getpid(),
|
||||
hostname=gethostname(),
|
||||
)
|
||||
|
||||
logger.info("Unique sample tag for this run is: %s", cls.identifier)
|
||||
logger.info(
|
||||
"You can use the following link to access the strobelight profile at the end of the run: %s",
|
||||
(
|
||||
"https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime"
|
||||
"ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no"
|
||||
"w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%"
|
||||
"22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n"
|
||||
"amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%"
|
||||
"2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22"
|
||||
"%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C"
|
||||
"%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang"
|
||||
"eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2"
|
||||
"C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack"
|
||||
"_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param"
|
||||
"_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22"
|
||||
"edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde"
|
||||
"r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[["
|
||||
"%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%"
|
||||
f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b"
|
||||
"_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi"
|
||||
"ew=GraphProfilerView&&normalized=1712358002&pool=uber"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _log_stats(cls) -> None:
|
||||
logger.info(
|
||||
"%s strobelight success runs out of %s non-recursive compilation events.",
|
||||
cls.success_profile_count,
|
||||
cls.success_profile_count + cls.failed_profile_count,
|
||||
)
|
||||
|
||||
# TODO use threadlevel meta data to tags to record phases.
|
||||
@classmethod
|
||||
def profile_compile_time(
|
||||
cls, func: Any, phase_name: str, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
if not cls.enabled:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if cls.profiler is None:
|
||||
logger.error("profiler is not set")
|
||||
return
|
||||
|
||||
if cls.inside_profile_compile_time:
|
||||
cls.ignored_profile_runs += 1
|
||||
logger.info(
|
||||
"profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored",
|
||||
phase_name,
|
||||
cls.current_phase,
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
cls.inside_profile_compile_time = True
|
||||
cls.current_phase = phase_name
|
||||
|
||||
work_result = cls.profiler.profile(func, *args, **kwargs)
|
||||
|
||||
if cls.profiler.profile_result is not None:
|
||||
cls.success_profile_count += 1
|
||||
else:
|
||||
cls.failed_profile_count += 1
|
||||
|
||||
cls._log_stats()
|
||||
cls.inside_profile_compile_time = False
|
||||
return work_result
|
||||
35
torch/_strobelight/examples/cli_function_profiler_example.py
Normal file
35
torch/_strobelight/examples/cli_function_profiler_example.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
|
||||
from torch._strobelight.cli_function_profiler import (
|
||||
strobelight,
|
||||
StrobelightCLIFunctionProfiler,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def fn(x, y, z):
|
||||
return x * y + z
|
||||
|
||||
# use decorator with default profiler or optional profile arguments.
|
||||
@strobelight(sample_each=10000, stop_at_error=False)
|
||||
@torch.compile()
|
||||
def work():
|
||||
for i in range(10):
|
||||
torch._dynamo.reset()
|
||||
for j in range(5):
|
||||
torch._dynamo.reset()
|
||||
fn(torch.rand(j, j), torch.rand(j, j), torch.rand(j, j))
|
||||
|
||||
work()
|
||||
|
||||
# or pass a profiler instance.
|
||||
profiler = StrobelightCLIFunctionProfiler(stop_at_error=False)
|
||||
|
||||
@strobelight(profiler, sample_tags=["something", "another"])
|
||||
def work2():
|
||||
sum = 0
|
||||
for i in range(100000000):
|
||||
sum += 1
|
||||
|
||||
work2()
|
||||
22
torch/_strobelight/examples/compile_time_profile_example.py
Normal file
22
torch/_strobelight/examples/compile_time_profile_example.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
|
||||
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
|
||||
|
||||
if __name__ == "__main__":
|
||||
# You can pass TORCH_COMPILE_STROBELIGHT=True instead.
|
||||
StrobelightCompileTimeProfiler.enable()
|
||||
|
||||
def fn(x, y, z):
|
||||
return x * y + z
|
||||
|
||||
@torch.compile()
|
||||
def work(n):
|
||||
for i in range(3):
|
||||
for j in range(5):
|
||||
fn(torch.rand(n, n), torch.rand(n, n), torch.rand(n, n))
|
||||
|
||||
# Strobelight will be called only 3 times because dynamo will be disabled after
|
||||
# 3rd iteration.
|
||||
for i in range(3):
|
||||
torch._dynamo.reset()
|
||||
work(i)
|
||||
|
|
@ -6,9 +6,20 @@ import tempfile
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
|
||||
import shutil
|
||||
|
||||
if not shutil.which("strobeclient"):
|
||||
log.info(
|
||||
"TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
|
||||
)
|
||||
else:
|
||||
log.info("Strobelight profiler is enabled via environment variable")
|
||||
StrobelightCompileTimeProfiler.enable()
|
||||
|
||||
# this arbitrary-looking assortment of functionality is provided here
|
||||
# to have a central place for overrideable behavior. The motivating
|
||||
|
|
@ -62,8 +73,6 @@ def throw_abstract_impl_not_imported_error(opname, module, context):
|
|||
)
|
||||
|
||||
|
||||
# Meta only, act as nop otherwise.
|
||||
#
|
||||
# NB! This treats "skip" kwarg specially!!
|
||||
def compile_time_strobelight_meta(phase_name):
|
||||
def compile_time_strobelight_meta_inner(function):
|
||||
|
|
@ -71,7 +80,9 @@ def compile_time_strobelight_meta(phase_name):
|
|||
def wrapper_function(*args, **kwargs):
|
||||
if "skip" in kwargs:
|
||||
kwargs["skip"] = kwargs["skip"] + 1
|
||||
return function(*args, **kwargs)
|
||||
return StrobelightCompileTimeProfiler.profile_compile_time(
|
||||
function, phase_name, *args, **kwargs
|
||||
)
|
||||
|
||||
return wrapper_function
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue