[fr][rfc] use a logger to control output for flight recorder analyzer (#139656)

Summary: Use a logger to control output to the console. This is useful for hiding debug/detail messages from the console vs. showing everything together.

Test Plan:
Ran `torchfrtrace` with various switches.

The `-v` verbose switch
```
torchfrtrace --prefix "trace_" /tmp/ -v
loaded 2 files in 0.2567298412322998s
built groups, memberships
Not all ranks joining collective 3 at entry 2
group info: 0:default_pg
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[4, 5]]
output sizes: [[4, 5]]
expected ranks: 2
collective state: scheduled
collective stack trace:
 <module> at /home/cpio/test/c.py:66
appending a non-matching collective
built collectives, nccl_calls
Groups
                  id  desc          size
--------------------  ----------  ------
09000494312501845833  default_pg       2
Memberships
            group_id    global_rank
--------------------  -------------
09000494312501845833              0
09000494312501845833              1
Collectives
  id    group_id
----  ----------
   0           0
   1           0
NCCLCalls
  id    collective_id    group_id    global_rank    traceback_id  collective_type    sizes
----  ---------------  ----------  -------------  --------------  -----------------  --------
   0                0           0              0               0  nccl:all_reduce    [[3, 4]]
   1                0           0              1               0  nccl:all_reduce    [[3, 4]]
   2                1           0              0               0  nccl:all_reduce    [[3, 4]]
   3                1           0              1               0  nccl:all_reduce    [[3, 4]]
   4                            0              0               0  nccl:all_reduce    [[4, 5]]
```

Without the verbose switch
```
❯ torchfrtrace --prefix "trace_" /tmp/
Not all ranks joining collective 3 at entry 2
group info: 0:default_pg
collective: nccl:all_reduce
missing ranks: {1}
input sizes: [[4, 5]]
output sizes: [[4, 5]]
expected ranks: 2
collective state: scheduled
collective stack trace:
 <module> at /home/cpio/test/c.py:66
```

With the `-j` switch:
```
❯ torchfrtrace --prefix "trace_" /tmp/ -j
Rank 0                                             Rank 1
-------------------------------------------------  -------------------------------------------------
all_reduce(input_sizes=[[3, 4]], state=completed)  all_reduce(input_sizes=[[3, 4]], state=completed)
all_reduce(input_sizes=[[3, 4]], state=completed)  all_reduce(input_sizes=[[3, 4]], state=completed)
all_reduce(input_sizes=[[4, 5]], state=scheduled)
```

Differential Revision: D65438520

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139656
Approved by: https://github.com/fduwjj
This commit is contained in:
Chirag Pandya 2024-11-05 20:14:18 +00:00 committed by PyTorch MergeBot
parent b9f0563aaf
commit d549ddfb14
4 changed files with 123 additions and 40 deletions

View file

@ -25,6 +25,7 @@ from tools.flight_recorder.components.utils import (
check_size_alltoall,
check_version,
find_coalesced_group,
FlightRecorderLogger,
format_frames,
get_version_detail,
just_print_entries,
@ -33,10 +34,14 @@ from tools.flight_recorder.components.utils import (
)
# Set up logging
logger: FlightRecorderLogger = FlightRecorderLogger()
try:
from tabulate import tabulate
except ModuleNotFoundError:
print("tabulate is not installed. Proceeding without it.")
logger.warning("tabulate is not installed. Proceeding without it.")
# Define a no-op tabulate function
def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc]
@ -311,13 +316,20 @@ def build_collectives(
# case one: not every rank join the collective or in the flight recorder.
if (candidate_ranks | found_ranks) != expected_ranks:
mismatch[pg_name] += 1
print(
f"Not all ranks joining collective {collective_seq_id} at entry {record_id}",
f" for group {pg_desc} collective {profiling_name} ",
f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ",
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
f"\nCollective stack traces: \n{collective_frames}",
logger.info(
"Not all ranks joining collective %s at entry %s",
collective_seq_id,
record_id,
)
logger.info("group info: %s", pg_desc)
logger.info("collective: %s", profiling_name)
missing_ranks = expected_ranks - (candidate_ranks | found_ranks)
logger.info("missing ranks: %s", missing_ranks)
logger.info("input sizes: %s", input_sizes)
logger.info("output sizes: %s", output_sizes)
logger.info("expected ranks: %d", len(expected_ranks))
logger.info("collective state: %s", collective_state)
logger.info("collective stack trace: \n %s", collective_frames)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
@ -338,13 +350,18 @@ def build_collectives(
if fail_check:
# When we see errors in all_to_all, it's hard to tell which rank is the source of the error.
mismatch[pg_name] += 1
print(
f"Input/output mismatch in the collective {collective_seq_id} ",
f"at entry {record_id} for group {pg_desc} collective {profiling_name} ",
f"input_numel {input_numel} output_numel {output_numel} ",
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
f"\nCollective stack traces: \n{collective_frames}",
logger.info(
"Input/output mismatch in the collective %s at entry %s",
collective_seq_id,
record_id,
)
logger.info("group info: %s", pg_desc)
logger.info("collective: %s", profiling_name)
logger.info("input sizes: %s", input_sizes)
logger.info("output sizes: %s", output_sizes)
logger.info("expected ranks: %d", len(expected_ranks))
logger.info("collective state: %s", collective_state)
logger.info("collective stack trace: \n%s", collective_frames)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
@ -366,13 +383,17 @@ def build_collectives(
error_msg = ", ".join(
f"Culprit rank {error[0]}; {str(error[1])}" for error in errors
)
print(
f"Collective {collective_seq_id} at entry {record_id} errors",
f" for group {pg_desc} collective {profiling_name} ",
f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ",
f"\nFound errors: {error_msg}.\n",
f"\nCollective stack traces: \n{collective_frames} ",
logger.info(
"Collective %s at entry %s errors", collective_seq_id, record_id
)
logger.info("group info: %s", pg_desc)
logger.info("collective: %s", profiling_name)
logger.info("input sizes: %s", input_sizes)
logger.info("output sizes: %s", output_sizes)
logger.info("expected ranks: %d", len(expected_ranks))
logger.info("collective state: %s", collective_state)
logger.info("error message: %s", error_msg)
logger.info("collective stack trace: \n%s", collective_frames)
candidate_ranks.update(found_ranks)
candidate_idx.update(found_idx)
found_idx.clear()
@ -402,7 +423,7 @@ def build_collectives(
# -> since its not a complete collective, no entry goes into collectives but we still record a nccl call
# TODO should there be a way to mark 'mismatches'?
else:
print("appending a non-matching collective")
logger.debug("appending a non-matching collective")
# TODO: figure out a better for mismatch.
# Also, shall we add seq Id as well?
for r in candidate_ranks:
@ -418,7 +439,9 @@ def build_collectives(
)
if mismatch[pg_name] > MISMATCH_TAIL:
print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting")
logger.error(
"Too many mismatches for process_group %s: %s aborting", pg_name, desc
)
sys.exit(-1)
return tracebacks, collectives, nccl_calls
@ -445,7 +468,7 @@ def build_db(
groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships(
pg_config
)
print("built groups, memberships")
logger.debug("built groups, memberships")
if args.just_print_entries:
just_print_entries(entries, _groups, _memberships, _pg_guids, args)
@ -456,12 +479,16 @@ def build_db(
tracebacks, collectives, nccl_calls = build_collectives(
entries, _groups, _memberships, _pg_guids, version
)
print("built collectives, nccl_calls")
logger.debug("built collectives, nccl_calls")
if args.verbose:
print("Groups\n", tabulate(groups, headers=Group._fields))
print("Memberships\n", tabulate(memberships, headers=Membership._fields))
print("Collectives\n", tabulate(collectives, headers=Collective._fields))
print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields))
logger.debug("Groups")
logger.debug(tabulate(groups, headers=Group._fields))
logger.debug("Memberships")
logger.debug(tabulate(memberships, headers=Membership._fields))
logger.debug("Collectives")
logger.debug(tabulate(collectives, headers=Collective._fields))
logger.debug("NCCLCalls")
logger.debug(tabulate(nccl_calls, headers=NCCLCall._fields))
db = Database(
tracebacks=tracebacks,
collectives=collectives,

View file

@ -5,8 +5,14 @@
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from typing import Optional, Sequence
from tools.flight_recorder.components.utils import FlightRecorderLogger
logger = FlightRecorderLogger()
class JobConfig:
"""
@ -64,4 +70,6 @@ class JobConfig:
assert (
args.just_print_entries
), "Not support selecting pg filters without printing entries"
if args.verbose:
logger.set_log_level(logging.DEBUG)
return args

View file

@ -13,6 +13,11 @@ import typing
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from tools.flight_recorder.components.utils import FlightRecorderLogger
logger: FlightRecorderLogger = FlightRecorderLogger()
def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]:
basename = os.path.basename(filename)
@ -52,7 +57,7 @@ def _determine_prefix(files: List[str]) -> str:
possible_prefixes[p].add(int(r))
if len(possible_prefixes) == 1:
prefix = next(iter(possible_prefixes))
print(f"Inferred common prefix {prefix}")
logger.debug("Inferred common prefix %s", prefix)
return prefix
else:
raise ValueError(
@ -68,6 +73,7 @@ def read_dir(
details = {}
t0 = time.time()
version = ""
filecount = 0
assert os.path.isdir(folder), f"folder {folder} does not exist"
for root, _, files in os.walk(folder):
if prefix is None:
@ -76,9 +82,10 @@ def read_dir(
if f.find(prefix) != 0:
continue
details[f] = read_dump(prefix, os.path.join(root, f))
filecount += 1
if not version:
version = str(details[f]["version"])
tb = time.time()
assert len(details) > 0, f"no files loaded from {folder} with prefix {prefix}"
print(f"loaded {len(files)} files in {tb - t0}s")
logger.debug("loaded %s files in %ss", filecount, tb - t0)
return details, version

View file

@ -5,22 +5,62 @@
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import math
from typing import Any, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from tools.flight_recorder.components.types import (
Group,
MatchState,
Membership,
Op,
P2P,
)
from .types import Group, MatchState, Membership, Op, P2P
class FlightRecorderLogger:
    """Singleton logger used by the flight recorder analyzer to control
    console output.

    The underlying ``logging.Logger`` defaults to INFO and uses a bare
    ``%(message)s`` formatter on a stream handler, so messages look like
    plain ``print`` output. Verbose mode raises the level to DEBUG via
    ``set_log_level`` (driven by the ``-v`` CLI flag).
    """

    # Lazily-created shared instance; all constructions return this object.
    _instance: Optional[Any] = None
    logger: logging.Logger

    def __init__(self) -> None:
        # NOTE(review): runs on every FlightRecorderLogger() call even though
        # __new__ returns the shared instance; logging.getLogger returns the
        # same named logger each time, so this re-assignment is harmless.
        self.logger: logging.Logger = logging.getLogger("Flight Recorder")

    def __new__(cls) -> Any:
        # First construction: build the singleton and attach exactly one
        # bare-message stream handler at INFO level. Guarding on _instance
        # prevents duplicate handlers (and thus duplicated output lines).
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.logger = logging.getLogger("Flight Recorder")
            cls._instance.logger.setLevel(logging.INFO)
            formatter = logging.Formatter("%(message)s")
            ch = logging.StreamHandler()
            ch.setFormatter(formatter)
            cls._instance.logger.addHandler(ch)
        return cls._instance

    def set_log_level(self, level: int) -> None:
        """Set the level of the underlying logger (e.g. ``logging.DEBUG``)."""
        self.logger.setLevel(level)

    # The properties below expose the bound logging methods so call sites can
    # write ``logger.debug("x=%s", x)`` with lazy %-style arguments, exactly
    # as they would on a plain logging.Logger.
    @property
    def debug(self) -> Callable[..., None]:
        return self.logger.debug

    @property
    def info(self) -> Callable[..., None]:
        return self.logger.info

    @property
    def warning(self) -> Callable[..., None]:
        return self.logger.warning

    @property
    def error(self) -> Callable[..., None]:
        return self.logger.error

    @property
    def critical(self) -> Callable[..., None]:
        return self.logger.critical
logger = FlightRecorderLogger()
try:
from tabulate import tabulate
except ModuleNotFoundError:
print("tabulate is not installed. Proceeding without it.")
logger.debug("tabulate is not installed. Proceeding without it.")
def format_frame(frame: Dict[str, str]) -> str:
@ -121,7 +161,8 @@ def match_coalesced_groups(
row = []
i += 1
title = "Match" if match else "MISMATCH"
print(f"{title}\n", tabulate(table)) # type: ignore[operator]
logger.info("%s \n", title)
logger.info("%s", tabulate(table)) # type: ignore[operator]
# TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg.
for op_list in all_ops.values():
@ -248,7 +289,7 @@ def just_print_entries(
if progress:
rows.append(row)
print(tabulate(rows, headers=headers))
logger.info(tabulate(rows, headers=headers))
def check_no_missing_dump_files(