Enable TORCH_TRACE by default in all Tupperware like environments (#120915)

Summary:
This is a reimplemented version of the FB specific code in https://www.internalfb.com/diff/D54230697

The new strategy is that we unconditionally install an FB handler on the trace_log logger (and always set its level to DEBUG). When the first log message is emitted, we check the JK/filesystem to see if we should actually do logging. If we decide not to log, we remove the handler from trace_log and are done.

build_only[github-export-checks,executorch,pytorch_benchmark,pytorch_quantization,pytorch_distributed,pytorch_distributed_gpu,pytorch_dynamo_inductor,pytorch_functorch,pytorch_fx2trt,pytorch_diff_train_tests_ads,glow_fb_pytorch_tests,training_platform,training_platform_compatibility,training_toolkit_applications,training_toolkit_examples,training_toolkit_model_optimization,dper3_pytorch,xplat_caffe2,pytorch_dev,android-pytorch-instrumentation-tests,smartpytorchgithub_first_try_merge,frl-target-determinator,f6-buck,training_platform_for_github,sigmoid_cpu,sigmoid_gpu,aiplatform_modelprocessing_for_github,accelerators_workloads_models_slimdsnn,ae_aotinductor_benchmark_test,aps_,aps_deterministic_ne_tests,dper_lib_silvertorch,torchrec,torchrec_fb,deeplearning_aot_inductor]

Test Plan:
sandcastle

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/inference/tests:test_single_gpu_executor -- --exact 'torchrec/inference/tests:test_single_gpu_executor - TorchDeployGPUTest.NestedModelSingleGPU'
buck2 test 'fbcode//mode/dev-nosan' fbcode//dper_lib/silvertorch/modules/dynamic_stats/tests:accumulators_test -- --exact 'dper_lib/silvertorch/modules/dynamic_stats/tests:accumulators_test - test_global_fixed_interval_accumulator (dper_lib.silvertorch.modules.dynamic_stats.tests.accumulators_test.GlobalFixedIntervalUnivalentAcculumatorTest)'
```

Also running a test flow with/without JK enabled

Differential Revision: D54275086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120915
Approved by: https://github.com/yanboliang
This commit is contained in:
Edward Yang 2024-03-01 04:47:13 +00:00 committed by PyTorch MergeBot
parent 518a23bb03
commit 02a410ee12
2 changed files with 89 additions and 12 deletions

View file

@ -6,6 +6,7 @@ import logging
import os
import os.path
import re
import tempfile
from dataclasses import dataclass, field
from importlib import __import__
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@ -844,7 +845,6 @@ def _reset_logs():
log.setLevel(logging.NOTSET)
log.propagate = True
trace_log.setLevel(logging.WARNING)
trace_log.propagate = False
_clear_handlers(trace_log)
@ -911,19 +911,26 @@ def _init_logs(log_file_name=None):
handler: Optional[logging.Handler] = None
if trace_file_name is not None:
handler = logging.FileHandler(trace_file_name)
if handler is not None:
trace_log.setLevel(logging.DEBUG)
trace_log_handler = _track_handler(handler)
trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
trace_log.addHandler(trace_log_handler)
else:
# This handler may remove itself if we are not actually in an FB
# environment. This allows us to defer actually initializing it until
# we actually need to log anything. This is important because JK
# initializes a C++ singleton, which will pork our process if we
# subsequently fork.
handler = LazyFbTraceHandler()
# This log is ALWAYS at debug level. We will additionally test if there
# are any handlers before deciding to actually call logging on this. Do
# not manually call
trace_log.setLevel(logging.DEBUG)
trace_log_handler = _track_handler(handler)
trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
trace_log.addHandler(trace_log_handler)
class FreshFileHandler(logging.StreamHandler):
class LazyFbTraceHandler(logging.StreamHandler):
"""Like FileHandler, but the file is allocated lazily only upon the first log message"""
def __init__(self, filename_cb):
self.filename_cb = filename_cb
self.filename = None
def __init__(self):
# This is implemented in the same way that delay is implemented on
# FileHandler
logging.Handler.__init__(self)
@ -954,8 +961,50 @@ class FreshFileHandler(logging.StreamHandler):
def emit(self, record):
if self.stream is None:
# TODO: more robust is_fbcode test
import torch.version
TRACE_LOG_DIR = "/logs"
open_func = self._builtin_open
self.stream = open_func(self.filename_cb(), "w")
ok = False
import torch.version as torch_version
if hasattr(torch_version, "git_version"):
log.info("LazyFbTraceHandler: disabled because not fbcode")
elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
log.info(
"LazyFbTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
)
elif not os.path.exists(TRACE_LOG_DIR):
log.info(
"LazyFbTraceHandler: disabled because %s does not exist",
TRACE_LOG_DIR,
)
elif not os.access(TRACE_LOG_DIR, os.W_OK):
log.info(
"LazyFbTraceHandler: disabled because %s is not writeable",
TRACE_LOG_DIR,
)
else:
ok = True
if ok:
ranksuffix = ""
if dist.is_available() and dist.is_initialized():
ranksuffix = f"rank_{dist.get_rank()}_"
self.stream = tempfile.NamedTemporaryFile(
mode="w+",
suffix=".log",
prefix=f"dedicated_log_torch_trace_{ranksuffix}",
dir=TRACE_LOG_DIR,
delete=False,
)
log.info("LazyFbTraceHandler: logging to %s", self.stream.name)
else:
# We go poof, remove and no-op
trace_log.removeHandler(self)
return
if self.stream:
super().emit(record)
@ -1004,7 +1053,9 @@ def trace_structured(
assert callable(
payload_fn
), f"payload_fn should be callable, but got {type(payload_fn)}"
if trace_log.isEnabledFor(logging.DEBUG):
# trace_log never propagates and is ALWAYS DEBUG, so also check that there
# are handlers instead of checking the log level
if trace_log.handlers:
record: Dict[str, object] = {}
record[name] = metadata_fn()
if not suppress_context:

View file

@ -95,6 +95,32 @@ def log_export_usage(**kwargs):
pass
def justknobs_check(name: str) -> bool:
    """Return whether the feature gated by ``name`` is enabled.

    In FB prod this is backed by JustKnobs, so functionality can be
    killswitched by toggling the knob to False without a code push.
    In OSS there is no live server to query, so everything is always
    on; downstream users who want old behavior can simply choose not
    to update PyTorch.  (If finer-grained enable/disable were needed,
    we could look ``name`` up in a map to toggle behavior — but the
    point stands that in OSS it is all tied to source code.)

    This is the bare minimum needed for some killswitches.  A more
    detailed plan lives at
    https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit
    In particular, in some circumstances a knob must be read once at
    process start and then used consistently for the rest of the
    process; future functionality will codify these patterns into a
    better high-level API.

    WARNING: Do NOT call this function at module import time.  JK is
    not fork safe, and doing so will break anyone who forks the
    process and then hits JK again.
    """
    # OSS stub: always enabled (the FB implementation consults JK).
    return True
@functools.lru_cache(None)
def max_clock_rate():
from triton.testing import nvsmi