diff --git a/test/distributed/flight_recorder/test_fr_analysis.py b/test/distributed/flight_recorder/test_fr_analysis.py index fd0c6df0cc9..2e389de7694 100644 --- a/test/distributed/flight_recorder/test_fr_analysis.py +++ b/test/distributed/flight_recorder/test_fr_analysis.py @@ -53,10 +53,10 @@ class FlightRecorderEventTest(TestCase): ) e3 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) e4 = create_one_event( - "alltoall", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 + "all_to_all", ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1 ) self.assertEqual(match_one_event(e3, e4, membership), MatchState.UNDECIDED) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index a0307c3b502..b3cb9f79246 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -244,7 +244,7 @@ def build_collectives( nccl_calls.extend(reversed(reversed_calls)) else: has_undecided_case = False - errors = Set() + errors = set() for o in expected_ranks.intersection(set(other_ranks)): for i, e in enumerate(all_entries[o]): # type: ignore[index] # step over ops from other PGs diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index e55c2370f30..4a33f3580e4 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -138,8 +138,7 @@ COLLECTIVES = { "_reduce_scatter_base", "gather", "scatter", - "alltoall_base", - "alltoall", + "all_to_all", } P2P = { @@ -158,7 +157,7 @@ class MatchState(Enum): - COLLECTIVE_STATE_MISMATCH: The states of the collective not same, such as one finished while another just started or scheduled. - UNDECIDED: - The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for alltoall_base. + The match status is ambiguous or cannot be determined, e.g., we might need to check all ranks for all_to_all. """ FULLY_MATCHED = 1 @@ -171,6 +170,8 @@ class MatchState(Enum): def check_size_evenly_broadcasting( list1: List[Any], list2: List[Any], size: int ) -> bool: + if len(list1) != len(list2): + return False ratio = None for a, b in zip(list1, list2): current_ratio = int(a) / int(b) @@ -283,7 +284,7 @@ class Op: elif self.type in COLLECTIVES: if self.type != other.type: return MatchState.COLLECTIVE_TYPE_MISMATCH - if self.type in ["alltoall", "alltoall_base"]: + if self.type == "all_to_all": return MatchState.UNDECIDED if self.type != "scatter" and self.input_sizes != other.input_sizes: return MatchState.SIZE_OR_SYNTAX_MISMATCH @@ -297,14 +298,14 @@ class Op: "all_gather", "all_gather_base", ] and not check_size_evenly_broadcasting( - other.output_sizes, self.input_sizes, self.pg_size + other.output_sizes[0], self.input_sizes[0], self.pg_size ): return MatchState.SIZE_OR_SYNTAX_MISMATCH if self.type in [ "reduce_scatter", "_reduce_scatter_base", ] and not check_size_evenly_broadcasting( - other.input_sizes, self.output_sizes, self.pg_size + other.input_sizes[0], self.output_sizes[0], self.pg_size ): return MatchState.SIZE_OR_SYNTAX_MISMATCH # TODO: need to add more checks for gather and scatter. diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index ef0e9a9f138..2f9d382261b 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -202,8 +202,8 @@ def check_no_missing_dump_files( ) -> None: all_ranks = set() for membership in memberships: - all_ranks.add(str(membership.global_rank)) - dumps_ranks = set(entries.keys()) + all_ranks.add(int(membership.global_rank)) + dumps_ranks = {int(key) for key in entries.keys()} assert ( dumps_ranks == all_ranks ), f"Missing dump files from ranks {all_ranks - dumps_ranks}"