[pipelining] Improve schedule csv loading (#142009)

Add small changes based on feedback from Less when testing out https://github.com/pytorch/torchtitan/pull/707
- expose `validate_schedule` as a function
- handle spaces around actions in csv file
- add error arrow to `_format_pipeline_order()` to better show where the step errored

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142009
Approved by: https://github.com/lessw2020
This commit is contained in:
Howard Huang 2024-12-03 16:58:24 -08:00 committed by PyTorch MergeBot
parent 86f306b15e
commit e8e65764d1
2 changed files with 133 additions and 75 deletions

View file

@ -25,6 +25,7 @@ from torch.distributed.pipelining.schedules import (
_PipelineSchedule,
_PipelineScheduleRuntime,
_simulate_comms_compute,
_validate_schedule,
B,
F,
get_schedule_class,
@ -863,6 +864,39 @@ class TestScheduleLowering(TestCase):
torch.distributed.destroy_process_group()
class TestValidateSchedule(TestCase):
    """Unit tests for the standalone `_validate_schedule` helper."""

    def test_valid_schedule(self):
        # Two ranks, each running F then B for its own stage over one
        # microbatch: a structurally complete schedule, so no error expected.
        schedule = {
            0: [_Action(0, F, 0), _Action(0, B, 0)],
            1: [_Action(1, F, 0), _Action(1, B, 0)],
        }
        _validate_schedule(schedule, 2, 2, 1)

    def test_invalid_schedule_missing_rank(self):
        # Only rank 0 appears although the group size is 2; validation
        # should trip the per-rank presence assertion.
        schedule = {
            0: [_Action(0, F, 0), _Action(0, B, 0)],
        }
        with self.assertRaises(AssertionError):
            _validate_schedule(schedule, 2, 2, 1)

    def test_invalid_schedule_missing_action(self):
        # Both ranks run F but never any backward, so the backward
        # microbatch count check should fail.
        schedule = {
            0: [_Action(0, F, 0)],
            1: [_Action(1, F, 0)],
        }
        with self.assertRaises(AssertionError):
            _validate_schedule(schedule, 2, 2, 1)
instantiate_parametrized_tests(TestScheduleLowering)
if __name__ == "__main__":

View file

@ -138,31 +138,38 @@ class _Action(NamedTuple):
return repr
@staticmethod
def from_str(action_string: str):
    """
    Reverse of __repr__

    String should be formatted as [stage][action type][(microbatch)]
    e.g. `2F0`, `1UNSHARD`, `3SEND_F1`

    Returns the parsed _Action, or None for a blank cell (e.g. an empty
    CSV field).  Raises RuntimeError for anything unparseable.
    """
    # Tolerate surrounding whitespace so hand-edited CSV schedules
    # ("2F0, 2B0") still parse.
    action_string = action_string.strip()
    if match := _action_regex.match(action_string):
        stage_index, computation_type, microbatch_index = match.groups()
        return _Action(
            int(stage_index),
            _ComputationType.from_str(computation_type),
            # Microbatch index is optional (e.g. `1UNSHARD` has none).
            int(microbatch_index) if len(microbatch_index) else None,
        )
    elif action_string == "":
        # Whitespace-only input was already stripped to "" above.
        return None
    raise RuntimeError(
        f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
    )
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
def _format_pipeline_order(
pipeline_order: Dict[int, List[Optional[_Action]]],
error_step_number: Optional[int] = None,
) -> str:
"""
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
and returns the formatted string
and returns the formatted string.
If `error_step_number` is passed in, an additional label will be added to signify which step
that it is erroring on.
"""
# don't mutate the original
@ -202,6 +209,12 @@ def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -
formatted_rows = [
f"{label}: "
+ " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
+ (
" <-- ERROR HERE"
if error_step_number is not None
and int(label.split()[1]) == error_step_number
else ""
)
for label, row in zip(step_labels, transposed_actions)
]
# Join the rows into a single string
@ -987,6 +1000,73 @@ def _add_send_recv(
return comm_actions
def _validate_schedule(
    actions: Dict[int, List[Optional[_Action]]],
    pp_group_size: int,
    num_stages: int,
    num_microbatches: int,
):
    """
    Validate that `actions` describes a structurally sound pipeline schedule.

    Asserts that:
    - there is exactly one entry per rank in [0, pp_group_size)
    - for each stage/microbatch, F runs before B or I, and B/I before W
    - each stage runs exactly `num_microbatches` forwards, and the full (B)
      plus split (I/W) backwards together cover `num_microbatches`

    Raises:
        AssertionError: on the first violated invariant.
    """
    assert (
        len(actions) == pp_group_size
    ), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
    for rank in range(pp_group_size):
        assert rank in actions, f"Schedule is missing actions for rank {rank}"

    # We will count all the actions per stage and ensure they happen in a valid order
    # (e.g. F before (B, I) before W for a given microbatch)
    stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
        stage_id: {
            F: set(),
            B: set(),
            I: set(),
            W: set(),
        }
        for stage_id in range(num_stages)
    }
    for rank in actions:
        for action in actions[rank]:
            if action is None:
                # None marks an idle timestep for this rank; nothing to record.
                continue
            assert isinstance(
                action, _Action
            ), f"Got an invalid action: {action}, expected instance of _Action"
            s_id = action.stage_index
            ctype = action.computation_type
            mb_id = action.microbatch_index
            if ctype == F:
                stage_actions[s_id][F].add(mb_id)
            elif ctype == B:
                # Full backward requires the forward for this microbatch first.
                assert (
                    mb_id in stage_actions[s_id][F]
                ), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                stage_actions[s_id][B].add(mb_id)
            elif ctype == I:
                # Backward-input also requires the forward first.
                assert (
                    mb_id in stage_actions[s_id][F]
                ), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
                stage_actions[s_id][I].add(mb_id)
            elif ctype == W:
                # Backward-weight must follow either the full backward (B) or the
                # split backward-input (I) for this microbatch.  Since I is now
                # tracked separately from B, checking only B would reject valid
                # split I/W schedules.
                assert (
                    mb_id in stage_actions[s_id][B]
                    or mb_id in stage_actions[s_id][I]
                ), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                stage_actions[s_id][W].add(mb_id)

    for s_id in stage_actions:
        f_mb = len(stage_actions[s_id][F])
        b_mb = len(stage_actions[s_id][B])
        i_mb = len(stage_actions[s_id][I])
        w_mb = len(stage_actions[s_id][W])
        assert (
            f_mb == num_microbatches
        ), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
        # Each microbatch's backward is either one full B or an I+W pair,
        # so I and W each count as half a backward.
        assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
            f"Invalid backward microbatches for stage {s_id}: expected "
            f"{num_microbatches} total backwards, but got B={b_mb}, I={i_mb}, W={w_mb}"
        )
class PipelineScheduleMulti(_PipelineSchedule):
"""
Base class for multi-stage schedules.
@ -1066,72 +1146,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
for rank in self.pipeline_order:
writer.writerow(self.pipeline_order[rank])
def _validate_schedule(self):
    """
    Validate `self.pipeline_order` against this schedule's rank count,
    stage count, and microbatch count, asserting on the first structural
    violation found (missing rank, out-of-order action, wrong counts).
    """
    # TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
    def _validate_rank_actions(
        actions: Dict[int, List[_Action | None]],
        num_stages: int,
        num_microbatches: int,
    ):
        # We will count all the actions per stage and ensure they happen in a valid order
        # (e.g. F before (B, I) before W for a given microbatch)
        stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
            stage_id: {
                F: set(),
                B: set(),
                W: set(),
            }
            for stage_id in range(num_stages)
        }
        for rank in actions:
            for action in actions[rank]:
                if action is None:
                    # None marks an idle timestep for this rank.
                    continue
                assert isinstance(
                    action, _Action
                ), f"Got an invalid action: {action}, expected instance of _Action"
                s_id = action.stage_index
                ctype = action.computation_type
                mb_id = action.microbatch_index
                if ctype == F:
                    stage_actions[s_id][F].add(mb_id)
                elif ctype == B:
                    # Full backward requires the forward for this microbatch first.
                    assert (
                        mb_id in stage_actions[s_id][F]
                    ), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
                    stage_actions[s_id][B].add(mb_id)
                elif ctype == I:
                    assert (
                        mb_id in stage_actions[s_id][F]
                    ), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
                    # TODO(whc) do we need to track I separately from B or should we just merge them for simplicity
                    stage_actions[s_id][B].add(mb_id)
                elif ctype == W:
                    # I is folded into B above, so a preceding I also satisfies this.
                    assert (
                        mb_id in stage_actions[s_id][B]
                    ), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
                    stage_actions[s_id][W].add(mb_id)

        for s_id in stage_actions:
            # Every stage must run each counted action type exactly once per microbatch.
            for ctype in (F, B, W):
                stage_mb = len(stage_actions[s_id][ctype])
                assert (
                    stage_mb == num_microbatches
                ), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"

    assert (
        len(self.pipeline_order) == self.pp_group_size
    ), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
    for rank in range(self.pp_group_size):
        assert (
            rank in self.pipeline_order
        ), f"Schedule is missing actions for rank {rank}"
    _validate_rank_actions(
        self.pipeline_order,
        self._num_stages,
        self._n_microbatches,
    )
def _load_csv(self, filename, format="compute_only"):
"""Load a CSV representation of the schedule from a file with the provided filename.
This API will most likely get renamed/refactored so is marked as internal for now.
@ -1143,7 +1157,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
reader = csv.reader(csvfile)
for rank, row in enumerate(reader):
self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
self._validate_schedule()
_validate_schedule(
self.pipeline_order,
self.pp_group_size,
self._num_stages,
self._n_microbatches,
)
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
"""
@ -1342,7 +1361,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
time_step,
action,
)
logger.error("%s", _format_pipeline_order(self.pipeline_order))
logger.error(
"%s",
_format_pipeline_order(
self.pipeline_order, error_step_number=time_step
),
)
raise e
# Return losses if there is a container passed in
self._update_losses(self._stages, losses)
@ -1649,7 +1673,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
)
# TODO(whc) what is the best practice for printing a multiline log?
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type]
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type]
raise e
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them