From d549ddfb149189aaf734c1b287c3d9f0f6b3cd9d Mon Sep 17 00:00:00 2001 From: Chirag Pandya Date: Tue, 5 Nov 2024 20:14:18 +0000 Subject: [PATCH] [fr][rfc] use a logger to control output for flight recorder analyzer (#139656) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Use a logger to control output to console. This is useful for hiding debug/detail messages from the console vs. showing everything together. Test Plan: Ran `torchfrtrace` with various switches. The `-v` verbose switch: ``` torchfrtrace --prefix "trace_" /tmp/ -v loaded 2 files in 0.2567298412322998s built groups, memberships Not all ranks joining collective 3 at entry 2 group info: 0:default_pg collective: nccl:all_reduce missing ranks: {1} input sizes: [[4, 5]] output sizes: [[4, 5]] expected ranks: 2 collective state: scheduled collective stack trace: at /home/cpio/test/c.py:66 appending a non-matching collective built collectives, nccl_calls Groups id desc size -------------------- ---------- ------ 09000494312501845833 default_pg 2 Memberships group_id global_rank -------------------- ------------- 09000494312501845833 0 09000494312501845833 1 Collectives id group_id ---- ---------- 0 0 1 0 NCCLCalls id collective_id group_id global_rank traceback_id collective_type sizes ---- --------------- ---------- ------------- -------------- ----------------- -------- 0 0 0 0 0 nccl:all_reduce [[3, 4]] 1 0 0 1 0 nccl:all_reduce [[3, 4]] 2 1 0 0 0 nccl:all_reduce [[3, 4]] 3 1 0 1 0 nccl:all_reduce [[3, 4]] 4 0 0 0 nccl:all_reduce [[4, 5]] ``` Without the verbose switch: ``` ❯ torchfrtrace --prefix "trace_" /tmp/ Not all ranks joining collective 3 at entry 2 group info: 0:default_pg collective: nccl:all_reduce missing ranks: {1} input sizes: [[4, 5]] output sizes: [[4, 5]] expected ranks: 2 collective state: scheduled collective stack trace: at /home/cpio/test/c.py:66 ``` With the `-j` switch: ``` ❯ torchfrtrace --prefix "trace_" /tmp/ -j Rank 0 Rank 1 
------------------------------------------------- ------------------------------------------------- all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[3, 4]], state=completed) all_reduce(input_sizes=[[4, 5]], state=scheduled) ``` Differential Revision: D65438520 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139656 Approved by: https://github.com/fduwjj --- tools/flight_recorder/components/builder.py | 81 ++++++++++++------- .../components/config_manager.py | 8 ++ tools/flight_recorder/components/loader.py | 11 ++- tools/flight_recorder/components/utils.py | 63 ++++++++++++--- 4 files changed, 123 insertions(+), 40 deletions(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index bd5b65e62c2..639c1246a97 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -25,6 +25,7 @@ from tools.flight_recorder.components.utils import ( check_size_alltoall, check_version, find_coalesced_group, + FlightRecorderLogger, format_frames, get_version_detail, just_print_entries, @@ -33,10 +34,14 @@ from tools.flight_recorder.components.utils import ( ) +# Set up logging +logger: FlightRecorderLogger = FlightRecorderLogger() + + try: from tabulate import tabulate except ModuleNotFoundError: - print("tabulate is not installed. Proceeding without it.") + logger.warning("tabulate is not installed. Proceeding without it.") # Define a no-op tabulate function def tabulate(data: Any, headers: Any = None) -> Any: # type: ignore[misc] @@ -311,13 +316,20 @@ def build_collectives( # case one: not every rank join the collective or in the flight recorder. 
if (candidate_ranks | found_ranks) != expected_ranks: mismatch[pg_name] += 1 - print( - f"Not all ranks joining collective {collective_seq_id} at entry {record_id}", - f" for group {pg_desc} collective {profiling_name} ", - f"Missing ranks are {expected_ranks - (candidate_ranks | found_ranks)} ", - f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", - f"\nCollective stack traces: \n{collective_frames}", + logger.info( + "Not all ranks joining collective %s at entry %s", + collective_seq_id, + record_id, ) + logger.info("group info: %s", pg_desc) + logger.info("collective: %s", profiling_name) + missing_ranks = expected_ranks - (candidate_ranks | found_ranks) + logger.info("missing ranks: %s", missing_ranks) + logger.info("input sizes: %s", input_sizes) + logger.info("output sizes: %s", output_sizes) + logger.info("expected ranks: %d", len(expected_ranks)) + logger.info("collective state: %s", collective_state) + logger.info("collective stack trace: \n %s", collective_frames) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) found_idx.clear() @@ -338,13 +350,18 @@ def build_collectives( if fail_check: # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. 
mismatch[pg_name] += 1 - print( - f"Input/output mismatch in the collective {collective_seq_id} ", - f"at entry {record_id} for group {pg_desc} collective {profiling_name} ", - f"input_numel {input_numel} output_numel {output_numel} ", - f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", - f"\nCollective stack traces: \n{collective_frames}", + logger.info( + "Input/output mismatch in the collective %s at entry %s", + collective_seq_id, + record_id, ) + logger.info("group info: %s", pg_desc) + logger.info("collective: %s", profiling_name) + logger.info("input sizes: %s", input_sizes) + logger.info("output sizes: %s", output_sizes) + logger.info("expected ranks: %d", len(expected_ranks)) + logger.info("collective state: %s", collective_state) + logger.info("collective stack trace: \n%s", collective_frames) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) found_idx.clear() @@ -366,13 +383,17 @@ def build_collectives( error_msg = ", ".join( f"Culprit rank {error[0]}; {str(error[1])}" for error in errors ) - print( - f"Collective {collective_seq_id} at entry {record_id} errors", - f" for group {pg_desc} collective {profiling_name} ", - f"{input_sizes} {output_sizes} {len(expected_ranks)} {collective_state} ", - f"\nFound errors: {error_msg}.\n", - f"\nCollective stack traces: \n{collective_frames} ", + logger.info( + "Collective %s at entry %s errors", collective_seq_id, record_id ) + logger.info("group info: %s", pg_desc) + logger.info("collective: %s", profiling_name) + logger.info("input sizes: %s", input_sizes) + logger.info("output sizes: %s", output_sizes) + logger.info("expected ranks: %d", len(expected_ranks)) + logger.info("collective state: %s", collective_state) + logger.info("error message: %s", error_msg) + logger.info("collective stack trace: \n%s", collective_frames) candidate_ranks.update(found_ranks) candidate_idx.update(found_idx) found_idx.clear() @@ -402,7 +423,7 @@ def build_collectives( # -> since its not 
a complete collective, no entry goes into collectives but we still record a nccl call # TODO should there be a way to mark 'mismatches'? else: - print("appending a non-matching collective") + logger.debug("appending a non-matching collective") # TODO: figure out a better for mismatch. # Also, shall we add seq Id as well? for r in candidate_ranks: @@ -418,7 +439,9 @@ def build_collectives( ) if mismatch[pg_name] > MISMATCH_TAIL: - print(f"Too many mismatches for process_group {pg_name}:{desc}, aborting") + logger.error( + "Too many mismatches for process_group %s: %s aborting", pg_name, desc + ) sys.exit(-1) return tracebacks, collectives, nccl_calls @@ -445,7 +468,7 @@ def build_db( groups, _groups, memberships, _memberships, _pg_guids = build_groups_memberships( pg_config ) - print("built groups, memberships") + logger.debug("built groups, memberships") if args.just_print_entries: just_print_entries(entries, _groups, _memberships, _pg_guids, args) @@ -456,12 +479,16 @@ def build_db( tracebacks, collectives, nccl_calls = build_collectives( entries, _groups, _memberships, _pg_guids, version ) - print("built collectives, nccl_calls") + logger.debug("built collectives, nccl_calls") if args.verbose: - print("Groups\n", tabulate(groups, headers=Group._fields)) - print("Memberships\n", tabulate(memberships, headers=Membership._fields)) - print("Collectives\n", tabulate(collectives, headers=Collective._fields)) - print("NCCLCalls\n", tabulate(nccl_calls, headers=NCCLCall._fields)) + logger.debug("Groups") + logger.debug(tabulate(groups, headers=Group._fields)) + logger.debug("Memberships") + logger.debug(tabulate(memberships, headers=Membership._fields)) + logger.debug("Collectives") + logger.debug(tabulate(collectives, headers=Collective._fields)) + logger.debug("NCCLCalls") + logger.debug(tabulate(nccl_calls, headers=NCCLCall._fields)) db = Database( tracebacks=tracebacks, collectives=collectives, diff --git a/tools/flight_recorder/components/config_manager.py 
b/tools/flight_recorder/components/config_manager.py index d4912c427b3..5a203bbe380 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -5,8 +5,14 @@ # LICENSE file in the root directory of this source tree. import argparse +import logging from typing import Optional, Sequence +from tools.flight_recorder.components.utils import FlightRecorderLogger + + +logger = FlightRecorderLogger() + class JobConfig: """ @@ -64,4 +70,6 @@ class JobConfig: assert ( args.just_print_entries ), "Not support selecting pg filters without printing entries" + if args.verbose: + logger.set_log_level(logging.DEBUG) return args diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index 442b000532f..30daad78752 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -13,6 +13,11 @@ import typing from collections import defaultdict from typing import Any, Dict, List, Optional, Set, Tuple, Union +from tools.flight_recorder.components.utils import FlightRecorderLogger + + +logger: FlightRecorderLogger = FlightRecorderLogger() + def read_dump(prefix: str, filename: str) -> Dict[str, Union[str, int, List[Any]]]: basename = os.path.basename(filename) @@ -52,7 +57,7 @@ def _determine_prefix(files: List[str]) -> str: possible_prefixes[p].add(int(r)) if len(possible_prefixes) == 1: prefix = next(iter(possible_prefixes)) - print(f"Inferred common prefix {prefix}") + logger.debug("Inferred common prefix %s", prefix) return prefix else: raise ValueError( @@ -68,6 +73,7 @@ def read_dir( details = {} t0 = time.time() version = "" + filecount = 0 assert os.path.isdir(folder), f"folder {folder} does not exist" for root, _, files in os.walk(folder): if prefix is None: @@ -76,9 +82,10 @@ def read_dir( if f.find(prefix) != 0: continue details[f] = read_dump(prefix, os.path.join(root, f)) + filecount += 1 if not version: version = 
str(details[f]["version"]) tb = time.time() assert len(details) > 0, f"no files loaded from {folder} with prefix {prefix}" - print(f"loaded {len(files)} files in {tb - t0}s") + logger.debug("loaded %s files in %ss", filecount, tb - t0) return details, version diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 7f2af5eeb29..37a6e100b60 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -5,22 +5,62 @@ # LICENSE file in the root directory of this source tree. import argparse +import logging import math -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple -from tools.flight_recorder.components.types import ( - Group, - MatchState, - Membership, - Op, - P2P, -) +from .types import Group, MatchState, Membership, Op, P2P + + +class FlightRecorderLogger: + _instance: Optional[Any] = None + logger: logging.Logger + + def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger("Flight Recorder") + + def __new__(cls) -> Any: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.logger = logging.getLogger("Flight Recorder") + cls._instance.logger.setLevel(logging.INFO) + formatter = logging.Formatter("%(message)s") + ch = logging.StreamHandler() + ch.setFormatter(formatter) + cls._instance.logger.addHandler(ch) + return cls._instance + + def set_log_level(self, level: int) -> None: + self.logger.setLevel(level) + + @property + def debug(self) -> Callable[..., None]: + return self.logger.debug + + @property + def info(self) -> Callable[..., None]: + return self.logger.info + + @property + def warning(self) -> Callable[..., None]: + return self.logger.warning + + @property + def error(self) -> Callable[..., None]: + return self.logger.error + + @property + def critical(self) -> Callable[..., None]: + return self.logger.critical + + +logger = FlightRecorderLogger() 
try: from tabulate import tabulate except ModuleNotFoundError: - print("tabulate is not installed. Proceeding without it.") + logger.debug("tabulate is not installed. Proceeding without it.") def format_frame(frame: Dict[str, str]) -> str: @@ -121,7 +161,8 @@ def match_coalesced_groups( row = [] i += 1 title = "Match" if match else "MISMATCH" - print(f"{title}\n", tabulate(table)) # type: ignore[operator] + logger.info("%s \n", title) + logger.info("%s", tabulate(table)) # type: ignore[operator] # TODO can't verify seq_id bc there might have been valid seq deltas between ranks even within a pg. for op_list in all_ops.values(): @@ -248,7 +289,7 @@ def just_print_entries( if progress: rows.append(row) - print(tabulate(rows, headers=headers)) + logger.info(tabulate(rows, headers=headers)) def check_no_missing_dump_files(