mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[fr][rfc] use a logger to control output for flight recorder analyzer (#139656)
Summary: Use a logger to control output to the console. This is useful for hiding debug/detail messages from the console, as opposed to showing everything together.
Test Plan:
Ran `torchfrtrace` with various switches.
The `-v` verbose switch
```
torchfrtrace --prefix "trace_" /tmp/ -v
loaded 2 files in 0.2567298412322998s
built groups, memberships
Not all ranks joining collective 3 at entry 2
group info: 0:default_pg
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[4, 5]]
output sizes: [[4, 5]]
expected ranks: 2
collective state: scheduled
collective stack trace:
<module> at /home/cpio/test/c.py:66
appending a non-matching collective
built collectives, nccl_calls
Groups
id desc size
-------------------- ---------- ------
09000494312501845833 default_pg 2
Memberships
group_id global_rank
-------------------- -------------
09000494312501845833 0
09000494312501845833 1
Collectives
id group_id
---- ----------
0 0
1 0
NCCLCalls
id collective_id group_id global_rank traceback_id collective_type sizes
---- --------------- ---------- ------------- -------------- ----------------- --------
0 0 0 0 0 nccl:all_reduce [[3, 4]]
1 0 0 1 0 nccl:all_reduce [[3, 4]]
2 1 0 0 0 nccl:all_reduce [[3, 4]]
3 1 0 1 0 nccl:all_reduce [[3, 4]]
4 0 0 0 nccl:all_reduce [[4, 5]]
```
Without the verbose switch
```
❯ torchfrtrace --prefix "trace_" /tmp/
Not all ranks joining collective 3 at entry 2
group info: 0:default_pg
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[4, 5]]
output sizes: [[4, 5]]
expected ranks: 2
collective state: scheduled
collective stack trace:
<module> at /home/cpio/test/c.py:66
```
With the `-j` switch:
```
❯ torchfrtrace --prefix "trace_" /tmp/ -j
Rank 0 Rank 1
------------------------------------------------- -------------------------------------------------
all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[3, 4]], state=completed)
all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[3, 4]], state=completed)
all_reduce(input_sizes=[[4, 5]], state=scheduled)
```
Differential Revision: D65438520
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139656
Approved by: https://github.com/fduwjj
This commit is contained in:
parent
b9f0563aaf
commit
d549ddfb14
4 changed files with 123 additions and 40 deletions
|
|
@ -25,6 +25,7 @@ from tools.flight_recorder.components.utils import (
|
|||
check_size_alltoall,
|
||||
check_version,
|
||||
find_coalesced_group,
|
||||
FlightRecorderLogger,
|
||||
format_frames,
|
||||
get_version_detail,
|
||||
just_print_entries,
|
||||
|
|
@ -33,10 +34,14 @@ from tools.flight_recorder.components.utils import (
|
|||
)
|
||||
|
||||
|
||||
# Set up logging
|
||||
logger: FlightRecorderLogger = FlightRecorderLogger()
|
||||
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ModuleNotFoundError:
|
||||
print("tabulate is not installed. Proceeding without it.")
|
||||
logger.warning("tabulate is not installed. Proceeding without it.")
|
||||
|
||||
# Define a no-op tabulate function
|
||||
def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc]
|
||||
|
|
@ -311,13 +316,20 @@ def build_collectives(
|
|||
# case one: not every rank join the collective or in the flight recorder.
|
||||
if (candidate_ranks | found_ranks) != expected_ranks:
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Not all ranks joining collective {collective_seq_id} at entry {record_id}",
|
||||
f" for group {pg_desc} collective {profiling_name} ",
|
||||
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nCollective stack traces: \n{collective_frames}",
|
||||
logger.info(
|
||||
"Not all ranks joining collective %s at entry %s",
|
||||
collective_seq_id,
|
||||
record_id,
|
||||
)
|
||||
logger.info("group info: %s", pg_desc)
|
||||
logger.info("collective: %s", profiling_name)
|
||||
missing_ranks = expected_ranks - (candidate_ranks | found_ranks)
|
||||
logger.info("missing ranks: %s", missing_ranks)
|
||||
logger.info("input sizes: %s", input_sizes)
|
||||
logger.info("output sizes: %s", output_sizes)
|
||||
logger.info("expected ranks: %d", len(expected_ranks))
|
||||
logger.info("collective state: %s", collective_state)
|
||||
logger.info("collective stack trace: \n %s", collective_frames)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
|
|
@ -338,13 +350,18 @@ def build_collectives(
|
|||
if fail_check:
|
||||
# When we see errors in all_to_all, it's hard to tell which rank is the source of the error.
|
||||
mismatch[pg_name] += 1
|
||||
print(
|
||||
f"Input/output mismatch in the collective {collective_seq_id} ",
|
||||
f"at entry {record_id} for group {pg_desc} collective {profiling_name} ",
|
||||
f"input_numel {input_numel} output_numel {output_numel} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nCollective stack traces: \n{collective_frames}",
|
||||
logger.info(
|
||||
"Input/output mismatch in the collective %s at entry %s",
|
||||
collective_seq_id,
|
||||
record_id,
|
||||
)
|
||||
logger.info("group info: %s", pg_desc)
|
||||
logger.info("collective: %s", profiling_name)
|
||||
logger.info("input sizes: %s", input_sizes)
|
||||
logger.info("output sizes: %s", output_sizes)
|
||||
logger.info("expected ranks: %d", len(expected_ranks))
|
||||
logger.info("collective state: %s", collective_state)
|
||||
logger.info("collective stack trace: \n%s", collective_frames)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
|
|
@ -366,13 +383,17 @@ def build_collectives(
|
|||
error_msg = ", ".join(
|
||||
f"Culprit rank {error[0]}; {str(error[1])}" for error in errors
|
||||
)
|
||||
print(
|
||||
f"Collective {collective_seq_id} at entry {record_id} errors",
|
||||
f" for group {pg_desc} collective {profiling_name} ",
|
||||
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
|
||||
f"\nFound errors: {error_msg}.\n",
|
||||
f"\nCollective stack traces: \n{collective_frames} ",
|
||||
logger.info(
|
||||
"Collective %s at entry %s errors", collective_seq_id, record_id
|
||||
)
|
||||
logger.info("group info: %s", pg_desc)
|
||||
logger.info("collective: %s", profiling_name)
|
||||
logger.info("input sizes: %s", input_sizes)
|
||||
logger.info("output sizes: %s", output_sizes)
|
||||
logger.info("expected ranks: %d", len(expected_ranks))
|
||||
logger.info("collective state: %s", collective_state)
|
||||
logger.info("error message: %s", error_msg)
|
||||
logger.info("collective stack trace: \n%s", collective_frames)
|
||||
candidate_ranks.update(found_ranks)
|
||||
candidate_idx.update(found_idx)
|
||||
found_idx.clear()
|
||||
|
|
@ -402,7 +423,7 @@ def build_collectives(
|
|||
# -> since its not a complete collective, no entry goes into collectives but we still record a nccl call
|
||||
# TODO should there be a way to mark 'mismatches'?
|
||||
else:
|
||||
print("appending a non-matching collective")
|
||||
logger.debug("appending a non-matching collective")
|
||||
# TODO: figure out a better for mismatch.
|
||||
# Also, shall we add seq Id as well?
|
||||
for r in candidate_ranks:
|
||||
|
|
@ -418,7 +439,9 @@ def build_collectives(
|
|||
)
|
||||
|
||||
if mismatch[pg_name] > MISMATCH_TAIL:
|
||||
print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
|
||||
logger.error(
|
||||
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
|
||||
)
|
||||
sys.exit(-1)
|
||||
|
||||
return tracebacks, collectives, nccl_calls
|
||||
|
|
@ -445,7 +468,7 @@ def build_db(
|
|||
groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
|
||||
pg_config
|
||||
)
|
||||
print("built groups, memberships")
|
||||
logger.debug("built groups, memberships")
|
||||
|
||||
if args.just_print_entries:
|
||||
just_print_entries(entries, _groups, _memberships, _pg_guids, args)
|
||||
|
|
@ -456,12 +479,16 @@ def build_db(
|
|||
tracebacks, collectives, nccl_calls = build_collectives(
|
||||
entries, _groups, _memberships, _pg_guids, version
|
||||
)
|
||||
print("built collectives, nccl_calls")
|
||||
logger.debug("built collectives, nccl_calls")
|
||||
if args.verbose:
|
||||
print("Groups\n", tabulate(groups, headers=Group._fields))
|
||||
print("Memberships\n", tabulate(memberships, headers=Membership._fields))
|
||||
print("Collectives\n", tabulate(collectives, headers=Collective._fields))
|
||||
print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields))
|
||||
logger.debug("Groups")
|
||||
logger.debug(tabulate(groups, headers=Group._fields))
|
||||
logger.debug("Memberships")
|
||||
logger.debug(tabulate(memberships, headers=Membership._fields))
|
||||
logger.debug("Collectives")
|
||||
logger.debug(tabulate(collectives, headers=Collective._fields))
|
||||
logger.debug("NCCLCalls")
|
||||
logger.debug(tabulate(nccl_calls, headers=NCCLCall._fields))
|
||||
db = Database(
|
||||
tracebacks=tracebacks,
|
||||
collectives=collectives,
|
||||
|
|
|
|||
|
|
@ -5,8 +5,14 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from tools.flight_recorder.components.utils import FlightRecorderLogger
|
||||
|
||||
|
||||
logger = FlightRecorderLogger()
|
||||
|
||||
|
||||
class JobConfig:
|
||||
"""
|
||||
|
|
@ -64,4 +70,6 @@ class JobConfig:
|
|||
assert (
|
||||
args.just_print_entries
|
||||
), "Not support selecting pg filters without printing entries"
|
||||
if args.verbose:
|
||||
logger.set_log_level(logging.DEBUG)
|
||||
return args
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@ import typing
|
|||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from tools.flight_recorder.components.utils import FlightRecorderLogger
|
||||
|
||||
|
||||
logger: FlightRecorderLogger = FlightRecorderLogger()
|
||||
|
||||
|
||||
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
|
||||
basename = os.path.basename(filename)
|
||||
|
|
@ -52,7 +57,7 @@ def _determine_prefix(files: List[str]) -> str:
|
|||
possible_prefixes[p].add(int(r))
|
||||
if len(possible_prefixes) == 1:
|
||||
prefix = next(iter(possible_prefixes))
|
||||
print(f"Inferred common prefix {prefix}")
|
||||
logger.debug("Inferred common prefix %s", prefix)
|
||||
return prefix
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
@ -68,6 +73,7 @@ def read_dir(
|
|||
details = {}
|
||||
t0 = time.time()
|
||||
version = ""
|
||||
filecount = 0
|
||||
assert os.path.isdir(folder), f"folder {folder} does not exist"
|
||||
for root, _, files in os.walk(folder):
|
||||
if prefix is None:
|
||||
|
|
@ -76,9 +82,10 @@ def read_dir(
|
|||
if f.find(prefix) != 0:
|
||||
continue
|
||||
details[f] = read_dump(prefix, os.path.join(root, f))
|
||||
filecount += 1
|
||||
if not version:
|
||||
version = str(details[f]["version"])
|
||||
tb = time.time()
|
||||
assert len(details) > 0, f"no files loaded from {folder} with prefix {prefix}"
|
||||
print(f"loaded {len(files)} files in {tb - t0}s")
|
||||
logger.debug("loaded %s files in %ss", filecount, tb - t0)
|
||||
return details, version
|
||||
|
|
|
|||
|
|
@ -5,22 +5,62 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from tools.flight_recorder.components.types import (
|
||||
Group,
|
||||
MatchState,
|
||||
Membership,
|
||||
Op,
|
||||
P2P,
|
||||
)
|
||||
from .types import Group, MatchState, Membership, Op, P2P
|
||||
|
||||
|
||||
class FlightRecorderLogger:
|
||||
_instance: Optional[Any] = None
|
||||
logger: logging.Logger
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.logger: logging.Logger = logging.getLogger("Flight Recorder")
|
||||
|
||||
def __new__(cls) -> Any:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance.logger = logging.getLogger("Flight Recorder")
|
||||
cls._instance.logger.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("%(message)s")
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter)
|
||||
cls._instance.logger.addHandler(ch)
|
||||
return cls._instance
|
||||
|
||||
def set_log_level(self, level: int) -> None:
|
||||
self.logger.setLevel(level)
|
||||
|
||||
@property
|
||||
def debug(self) -> Callable[..., None]:
|
||||
return self.logger.debug
|
||||
|
||||
@property
|
||||
def info(self) -> Callable[..., None]:
|
||||
return self.logger.info
|
||||
|
||||
@property
|
||||
def warning(self) -> Callable[..., None]:
|
||||
return self.logger.warning
|
||||
|
||||
@property
|
||||
def error(self) -> Callable[..., None]:
|
||||
return self.logger.error
|
||||
|
||||
@property
|
||||
def critical(self) -> Callable[..., None]:
|
||||
return self.logger.critical
|
||||
|
||||
|
||||
logger = FlightRecorderLogger()
|
||||
|
||||
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
except ModuleNotFoundError:
|
||||
print("tabulate is not installed. Proceeding without it.")
|
||||
logger.debug("tabulate is not installed. Proceeding without it.")
|
||||
|
||||
|
||||
def format_frame(frame: Dict[str, str]) -> str:
|
||||
|
|
@ -121,7 +161,8 @@ def match_coalesced_groups(
|
|||
row = []
|
||||
i += 1
|
||||
title = "Match" if match else "MISMATCH"
|
||||
print(f"{title}\n", tabulate(table)) # type: ignore[operator]
|
||||
logger.info("%s \n", title)
|
||||
logger.info("%s", tabulate(table)) # type: ignore[operator]
|
||||
|
||||
# TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg.
|
||||
for op_list in all_ops.values():
|
||||
|
|
@ -248,7 +289,7 @@ def just_print_entries(
|
|||
if progress:
|
||||
rows.append(row)
|
||||
|
||||
print(tabulate(rows, headers=headers))
|
||||
logger.info(tabulate(rows, headers=headers))
|
||||
|
||||
|
||||
def check_no_missing_dump_files(
|
||||
|
|
|
|||
Loading…
Reference in a new issue