# mypy: disallow-untyped-defs

import functools
import logging
import os
import re
import subprocess
import time
from threading import Lock
from timeit import default_timer as timer
from typing import Any, List, Optional, Sequence, Tuple


logger = logging.getLogger("strobelight_function_profiler")

console_handler = logging.StreamHandler()
formatter = logging.Formatter(
    "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s"
)
console_handler.setFormatter(formatter)

logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
logger.propagate = False

# Lines of the `strobeclient getRunStatus` output that are kept as the result.
_RESULT_LINE_RE = re.compile(
    r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))"
)

# _wait_for_running gives up once its poll counter exceeds this value.
_MAX_WAIT_FOR_RUNNING_POLLS = 20

# Seconds to sleep between two status polls.
_POLL_INTERVAL_SEC = 10


class StrobelightCLIProfilerError(Exception):
    """
    Raised when an error happens during strobelight profiling
    """


def _pid_namespace_link(pid: Optional[int] = None) -> str:
    """Returns the link to the process's namespace, example: pid:[4026531836]"""
    PID_NAMESPACE_PATH = "/proc/{}/ns/pid"
    pid = pid or os.getpid()
    return os.readlink(PID_NAMESPACE_PATH.format(pid))


def _pid_namespace(pid: Optional[int] = None) -> int:
    """Returns the process's namespace id"""
    pid = pid or os.getpid()
    link = _pid_namespace_link(pid)
    # The link looks like "pid:[4026531836]"; keep what is between the brackets.
    return int(link[link.find("[") + 1 : -1])


def _command_to_string(command: Sequence[str]) -> str:
    """Joins the command tokens with single spaces (used for logging only)."""
    return " ".join(command)


def _run_strobeclient(command: Sequence[str]) -> Tuple[int, str]:
    """
    Runs a strobeclient command and returns (return code, decoded stderr).

    All the parsing in this module is done on stderr, so that is the stream
    that is decoded and returned.
    """
    logger.debug("running command: %s", _command_to_string(command))
    result = subprocess.run(command, capture_output=True)
    output = result.stderr.decode("utf-8")
    logger.debug("output:\n{%s}", output)
    return result.returncode, output


class StrobelightCLIFunctionProfiler:
    """
    Note: this is a Meta only tool.

    StrobelightCLIFunctionProfiler can be used to profile a python function and
    generate a strobelight link with the results. It works on meta servers but
    does not require an fbcode target.
    When stop_at_error is false(default), error during profiling does not prevent
    the work function from running.

    Check cli_function_profiler_example.py for an example.
    """

    # This lock is used to make sure only one thread is running the profiler at any point.
    _lock = Lock()

    def __init__(
        self,
        *,
        stop_at_error: bool = False,
        max_profile_duration_sec: int = 60 * 10,
        sample_each: float = 1e7,  # sample each sample_each cycles.
        run_user_name: str = "pytorch-strobelight-ondemand",
        timeout_wait_for_running_sec: int = 60,
        timeout_wait_for_finished_sec: int = 60,
        recorded_env_variables: Optional[List[str]] = None,
        sample_tags: Optional[List[str]] = None,
        stack_max_len: int = 127,
        async_stack_max_len: int = 127,
    ) -> None:
        """
        Stores the profiling options; nothing is started until profile() is called.

        Only sample_each, max_profile_duration_sec and sample_tags are turned
        into strobeclient arguments by this class. The remaining options are
        kept as attributes so that no constructor argument is silently lost.
        """
        self.stop_at_error = stop_at_error
        self.max_profile_duration_sec = max_profile_duration_sec
        self.sample_each = sample_each
        self.run_user_name = run_user_name
        self.timeout_wait_for_running_sec = timeout_wait_for_running_sec
        self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec
        # Previously accepted but dropped; stored so callers can read them back.
        self.recorded_env_variables = recorded_env_variables
        self.stack_max_len = stack_max_len
        self.async_stack_max_len = async_stack_max_len
        # Results of the most recent run.
        # Tracks the strobelight run id of the most recent run
        self.current_run_id: Optional[int] = None
        self.profile_result: Optional[List[str]] = None
        self.sample_tags = sample_tags

    def _run_async(self) -> None:
        """
        Starts an asynchronous pyperf run for the current process and records
        its run id in self.current_run_id.

        Raises StrobelightCLIProfilerError when strobeclient fails or its
        output does not contain a run id.
        """
        process_id = os.getpid()
        namespace = _pid_namespace(process_id)
        command = [
            "strobeclient",
            "run",
            "--profiler",
            "pyperf",
            "--event",
            "cycles",
            "--async",
            "--sample-interval",
            f"{int(self.sample_each)}",
            "--duration-ms",
            f"{int(self.max_profile_duration_sec * 1000)}",
            "--pid",
            f"{namespace}:{process_id}",
        ]

        if self.sample_tags:
            command.append("--sample-tags")
            command.append(",".join(self.sample_tags))

        returncode, output = _run_strobeclient(command)

        if returncode != 0:
            raise StrobelightCLIProfilerError(
                f"failed to start strobelight profiling, error in run_async:{output}"
            )

        match = re.search(r"INFO Run Id: (-?\d+)", output)
        if match is None:
            raise StrobelightCLIProfilerError(
                f"failed to start strobelight profiling, unexpected result {output}"
            )
        self.current_run_id = int(match.group(1))

    def _wait_for_running(self, counter: int = 0) -> None:
        """
        Polls the run status until it is RUNNING.

        While the status is PREPARING the method sleeps and polls again; it
        raises StrobelightCLIProfilerError once the poll counter exceeds 20,
        when strobeclient fails, or when any other status is reported.
        `counter` is the number of polls already done.
        """
        command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"]

        while counter <= _MAX_WAIT_FOR_RUNNING_POLLS:
            returncode, output = _run_strobeclient(command)

            if returncode != 0:
                raise StrobelightCLIProfilerError(
                    f"failed to start strobelight profiling, error in wait_for_running:{output}"
                )

            match = re.search("Profile run status: (.*)", output)
            if match is None:
                raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")

            current_status = match.group(1)
            if current_status == "RUNNING":
                return
            if current_status != "PREPARING":
                raise StrobelightCLIProfilerError(f"unexpected {current_status} phase")

            time.sleep(_POLL_INTERVAL_SEC)
            counter += 1

        raise StrobelightCLIProfilerError("wait_for_running called more than 20 times")

    def _stop_run(self) -> None:
        """
        Asks strobeclient to stop the current run.

        Raises StrobelightCLIProfilerError unless the command succeeds and
        reports "Success!".
        """
        command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)]
        returncode, output = _run_strobeclient(command)

        if returncode != 0:
            raise StrobelightCLIProfilerError(
                f"failed to stop strobelight profiling, return code is not 0 :{output}"
            )

        match = re.search("INFO ::1:(.*)", output)
        if match is None:
            raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ")

        current_status = match.group(1)
        if "Success!" not in current_status:
            raise StrobelightCLIProfilerError(
                f"failed to stop strobelight profiling, got {current_status} result"
            )

    @staticmethod
    def _extract_result_lines(output: str) -> List[str]:
        """Returns the result lines (sample totals and result links) found in output."""
        # findall yields one tuple per match; element 0 is the whole matched line.
        return [groups[0] for groups in _RESULT_LINE_RE.findall(output)]

    def _get_results(self) -> None:
        """
        Waits for the run to leave the PROCESSING state, then stores the
        result lines in self.profile_result and logs them.

        Raises StrobelightCLIProfilerError when strobeclient fails or reports a
        status other than PROCESSING or SUCCESS.
        """
        command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)]

        while True:
            returncode, output = _run_strobeclient(command)

            if returncode != 0:
                raise StrobelightCLIProfilerError(
                    f"failed to extract profiling results, return code is not 0 : {output}"
                )

            match = re.search("INFO ::1:(.*)", output)
            if match is None:
                # No status line: fall through and collect whatever is there.
                break

            current_status = match.group(1)
            if "Profile run status: PROCESSING" in current_status:
                time.sleep(_POLL_INTERVAL_SEC)
                continue
            if "Profile run finished with SUCCESS" not in current_status:
                raise StrobelightCLIProfilerError(
                    f"failed to extract profiling results, unexpected response {output}"
                )
            break

        # Append whole lines; `list += str` would add one entry per character.
        self.profile_result = self._extract_result_lines(output)
        for line in self.profile_result:
            logger.info("%s", line)

    def _stop_strobelight_no_throw(
        self,
        collect_results: bool,
    ) -> None:
        """Stops the run and optionally collects results; errors are only logged."""
        try:
            # call stop run
            self._stop_run()
            logger.info("strobelight profiling stopped")

            logger.debug("collection stopped")

            if not collect_results:
                return

            self._get_results()
        except Exception:
            logger.warning("error during stop_strobelight", exc_info=True)

    # Return true if strobelight started and is running. Never throw.
    def _start_strobelight(self) -> bool:
        """Starts a run and waits until it is running; returns False on any error."""
        strobelight_started = False
        try:
            self._run_async()
            strobelight_started = True
            logger.info("strobelight run id is: %s", self.current_run_id)
            self._wait_for_running()
            logger.info("strobelight profiling running")
            return True

        except Exception:
            logger.warning("error during start_strobelight:", exc_info=True)
            if strobelight_started:
                # The run exists but never reached RUNNING; do not leave it behind.
                self._stop_strobelight_no_throw(collect_results=False)
            return False

    def profile(self, work_function: Any, *args: Any, **kwargs: Any) -> Any:
        """
        Runs work_function(*args, **kwargs) under the profiler and returns its result.

        Only one profile() call runs the profiler at a time. If another call
        holds the lock, or the profiler cannot be started, the work function
        still runs (unprofiled) unless stop_at_error is set, in which case
        StrobelightCLIProfilerError is raised. Exceptions raised by the work
        function are re-raised after the run is stopped.
        """
        self.current_run_id = None
        self.profile_result = None

        if not StrobelightCLIFunctionProfiler._lock.acquire(False):
            if self.stop_at_error:
                raise StrobelightCLIProfilerError("concurrent runs not supported")

            logger.warning("concurrent runs not supported")
            return work_function(*args, **kwargs)

        # From here on the lock is held; `finally` releases it on every path.
        try:
            if not self._start_strobelight():
                if self.stop_at_error:
                    raise StrobelightCLIProfilerError(
                        "failed to start strobelight profiling"
                    )
                return work_function(*args, **kwargs)

            logger.debug("collection started")
            start = timer()
            try:
                result = work_function(*args, **kwargs)
            except Exception:
                logger.warning("work function throw exception", exc_info=True)
                self._stop_strobelight_no_throw(collect_results=False)
                raise
            total_time = timer() - start  # Time in seconds, e.g. 5.38091952400282
            logger.info("work function took %s seconds", total_time)
            self._stop_strobelight_no_throw(collect_results=True)
            return result
        finally:
            StrobelightCLIFunctionProfiler._lock.release()


# A function decorator that wraps profile, if no profiler is provided one with
# default args is created. A function can be annotated as:
# @strobelight()
# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..))
# @strobelight(stop_at_error=True,...)
def strobelight(
    profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any
) -> Any:
    """
    Returns a decorator that runs the decorated function through profiler.profile.

    The keyword options are used only to build a profiler when none is given;
    when a profiler is supplied they are ignored and a warning is logged.
    """
    if profiler is None:
        active_profiler = StrobelightCLIFunctionProfiler(**kwargs)
    else:
        active_profiler = profiler
        if kwargs:
            logger.warning(
                "strobelight options %s are ignored because a profiler was provided",
                sorted(kwargs),
            )

    def strobelight_inner(work_function: Any) -> Any:
        @functools.wraps(work_function)
        def wrapper_function(*call_args: Any, **call_kwargs: Any) -> Any:
            return active_profiler.profile(work_function, *call_args, **call_kwargs)

        return wrapper_function

    return strobelight_inner
StrobelightCompileTimeProfiler: + success_profile_count: int = 0 + failed_profile_count: int = 0 + ignored_profile_runs: int = 0 + inside_profile_compile_time: bool = False + enabled: bool = False + # A unique identifier that is used as the run_user_name in the strobelight profile to + # associate all compile time profiles together. + identifier: Optional[str] = None + + current_phase: Optional[str] = None + + profiler: Optional[Any] = None + + max_stack_length: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 127) + ) + max_profile_time: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30) + ) + # Collect sample each x cycles. + sample_each: int = int( + float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7)) + ) + + @classmethod + def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: + if cls.enabled: + logger.info("compile time strobelight profiling already enabled") + return + + logger.info("compile time strobelight profiling enabled") + + if profiler_class is StrobelightCLIFunctionProfiler: + import shutil + + if not shutil.which("strobeclient"): + logger.info( + "strobeclient not found, cant enable compile time strobelight profiling, seems" + "like you are not on a FB machine." + ) + return + + cls.enabled = True + cls._cls_init() + # profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler. + # we have pass different functionProfilerClass for meta-internal fbcode targets. 
+ cls.profiler = profiler_class( + sample_each=cls.sample_each, + max_profile_duration_sec=cls.max_profile_time, + stack_max_len=cls.max_stack_length, + async_stack_max_len=cls.max_stack_length, + run_user_name="pt2-profiler/" + + os.environ.get("USER", os.environ.get("USERNAME", "")), + sample_tags={cls.identifier}, + ) + + @classmethod + def _cls_init(cls) -> None: + cls.identifier = "{date}{pid}{hostname}".format( + date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"), + pid=os.getpid(), + hostname=gethostname(), + ) + + logger.info("Unique sample tag for this run is: %s", cls.identifier) + logger.info( + "You can use the following link to access the strobelight profile at the end of the run: %s", + ( + "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experime" + "ntal%2Fon_demand&drillstate=%7B%22purposes%22%3A[]%2C%22end%22%3A%22no" + "w%22%2C%22start%22%3A%22-30%20days%22%2C%22filterMode%22%3A%22DEFAULT%" + "22%2C%22modifiers%22%3A[]%2C%22sampleCols%22%3A[]%2C%22cols%22%3A[%22n" + "amespace_id%22%2C%22namespace_process_id%22]%2C%22derivedCols%22%3A[]%" + "2C%22mappedCols%22%3A[]%2C%22enumCols%22%3A[]%2C%22return_remainder%22" + "%3Afalse%2C%22should_pivot%22%3Afalse%2C%22is_timeseries%22%3Afalse%2C" + "%22hideEmptyColumns%22%3Afalse%2C%22timezone%22%3A%22America%2FLos_Ang" + "eles%22%2C%22compare%22%3A%22none%22%2C%22samplingRatio%22%3A%221%22%2" + "C%22metric%22%3A%22count%22%2C%22aggregation_field%22%3A%22async_stack" + "_complete%22%2C%22top%22%3A10000%2C%22aggregateList%22%3A[]%2C%22param" + "_dimensions%22%3A[%7B%22dim%22%3A%22py_async_stack%22%2C%22op%22%3A%22" + "edge%22%2C%22param%22%3A%220%22%2C%22anchor%22%3A%220%22%7D]%2C%22orde" + "r%22%3A%22weight%22%2C%22order_desc%22%3Atrue%2C%22constraints%22%3A[[" + "%7B%22column%22%3A%22sample_tags%22%2C%22op%22%3A%22all%22%2C%22value%" + f"22%3A[%22[%5C%22{cls.identifier}%5C%22]%22]%7D]]%2C%22c_constraints%22%3A[[]]%2C%22b" + 
"_constraints%22%3A[[]]%2C%22ignoreGroupByInComparison%22%3Afalse%7D&vi" + "ew=GraphProfilerView&&normalized=1712358002&pool=uber" + ), + ) + + @classmethod + def _log_stats(cls) -> None: + logger.info( + "%s strobelight success runs out of %s non-recursive compilation events.", + cls.success_profile_count, + cls.success_profile_count + cls.failed_profile_count, + ) + + # TODO use threadlevel meta data to tags to record phases. + @classmethod + def profile_compile_time( + cls, func: Any, phase_name: str, *args: Any, **kwargs: Any + ) -> Any: + if not cls.enabled: + return func(*args, **kwargs) + + if cls.profiler is None: + logger.error("profiler is not set") + return + + if cls.inside_profile_compile_time: + cls.ignored_profile_runs += 1 + logger.info( + "profile_compile_time is requested for phase: %s while already in running phase: %s, recursive call ignored", + phase_name, + cls.current_phase, + ) + return func(*args, **kwargs) + + cls.inside_profile_compile_time = True + cls.current_phase = phase_name + + work_result = cls.profiler.profile(func, *args, **kwargs) + + if cls.profiler.profile_result is not None: + cls.success_profile_count += 1 + else: + cls.failed_profile_count += 1 + + cls._log_stats() + cls.inside_profile_compile_time = False + return work_result diff --git a/torch/_strobelight/examples/cli_function_profiler_example.py b/torch/_strobelight/examples/cli_function_profiler_example.py new file mode 100644 index 00000000000..8142ef1bdc7 --- /dev/null +++ b/torch/_strobelight/examples/cli_function_profiler_example.py @@ -0,0 +1,35 @@ +import torch + +from torch._strobelight.cli_function_profiler import ( + strobelight, + StrobelightCLIFunctionProfiler, +) + + +if __name__ == "__main__": + + def fn(x, y, z): + return x * y + z + + # use decorator with default profiler or optional profile arguments. 
import torch

from torch._strobelight.cli_function_profiler import (
    strobelight,
    StrobelightCLIFunctionProfiler,
)


if __name__ == "__main__":

    def fn(x, y, z):
        return x * y + z

    # use decorator with default profiler or optional profile arguments.
    @strobelight(sample_each=10000, stop_at_error=False)
    @torch.compile()
    def work():
        for i in range(10):
            torch._dynamo.reset()
            for j in range(5):
                torch._dynamo.reset()
                fn(torch.rand(j, j), torch.rand(j, j), torch.rand(j, j))

    work()

    # or pass a profiler instance. Options such as sample_tags belong on the
    # profiler itself: the decorator only uses keyword options when it has to
    # build a profiler, so passing them next to an instance has no effect.
    profiler = StrobelightCLIFunctionProfiler(
        stop_at_error=False, sample_tags=["something", "another"]
    )

    @strobelight(profiler)
    def work2():
        # CPU-bound busy loop so the profiler has something to sample.
        total = 0
        for _ in range(100000000):
            total += 1

    work2()
+ ) + else: + log.info("Strobelight profiler is enabled via environment variable") + StrobelightCompileTimeProfiler.enable() # this arbitrary-looking assortment of functionality is provided here # to have a central place for overrideable behavior. The motivating @@ -62,8 +73,6 @@ def throw_abstract_impl_not_imported_error(opname, module, context): ) -# Meta only, act as nop otherwise. -# # NB! This treats "skip" kwarg specially!! def compile_time_strobelight_meta(phase_name): def compile_time_strobelight_meta_inner(function): @@ -71,7 +80,9 @@ def compile_time_strobelight_meta(phase_name): def wrapper_function(*args, **kwargs): if "skip" in kwargs: kwargs["skip"] = kwargs["skip"] + 1 - return function(*args, **kwargs) + return StrobelightCompileTimeProfiler.profile_compile_time( + function, phase_name, *args, **kwargs + ) return wrapper_function