[fr][c10d] fix flaky test (#143878)

Summary:
The test erroneously assumed that input and output sizes are always the
same and that all states are matchable.

Fixes issue #143798

Test Plan:
Test passes

Reviewers
Test passes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143878
Approved by: https://github.com/fduwjj
ghstack dependencies: #143865
This commit is contained in:
Chirag Pandya 2024-12-26 12:20:06 -08:00 committed by PyTorch MergeBot
parent 1cd70e7e23
commit 809106a93f

View file

@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import math
import pathlib
import sys
@ -113,14 +114,22 @@ class FlightRecorderEventTest(TestCase):
def test_all_events(self):
for collective in sorted(COLLECTIVES):
input_sizes = [[4, 4]]
output_sizes = [[4, 4]]
expectedState = MatchState.FULLY_MATCHED
if collective == "_reduce_scatter_base":
input_sizes = [[4, 4]]
output_sizes = [[input_sizes[0][0] * 2]]
if collective == "all_gather":
output_sizes = [[math.prod(input_sizes[0]) * 2]]
if collective == "all_to_all":
expectedState = MatchState.UNDECIDED
event = create_one_event(
collective, ("0", "default"), [[4, 4]], [[4, 4]], "scheduled", 1
collective, ("0", "default"), input_sizes, output_sizes, "scheduled", 1
)
membership = {"0": {0, 1}}
self.assertEqual(
match_one_event(event, event, membership, "0"), MatchState.FULLY_MATCHED
)
break
result = match_one_event(event, event, membership, "0")
self.assertEqual(result, expectedState)
if __name__ == "__main__":