mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[pipelining] Improve schedule csv loading (#142009)
Add small changes based on feedback from Less when testing out https://github.com/pytorch/torchtitan/pull/707 - expose `validate_schedule` as a function - handle spaces around actions in csv file - add error arrow to `_format_pipeline_schedule()` to better show where the step errored Pull Request resolved: https://github.com/pytorch/pytorch/pull/142009 Approved by: https://github.com/lessw2020
This commit is contained in:
parent
86f306b15e
commit
e8e65764d1
2 changed files with 133 additions and 75 deletions
|
|
@ -25,6 +25,7 @@ from torch.distributed.pipelining.schedules import (
|
|||
_PipelineSchedule,
|
||||
_PipelineScheduleRuntime,
|
||||
_simulate_comms_compute,
|
||||
_validate_schedule,
|
||||
B,
|
||||
F,
|
||||
get_schedule_class,
|
||||
|
|
@ -863,6 +864,39 @@ class TestScheduleLowering(TestCase):
|
|||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
class TestValidateSchedule(TestCase):
|
||||
def test_valid_schedule(self):
|
||||
actions = {
|
||||
0: [_Action(0, F, 0), _Action(0, B, 0)],
|
||||
1: [_Action(1, F, 0), _Action(1, B, 0)],
|
||||
}
|
||||
pp_group_size = 2
|
||||
num_stages = 2
|
||||
num_microbatches = 1
|
||||
_validate_schedule(actions, pp_group_size, num_stages, num_microbatches)
|
||||
|
||||
def test_invalid_schedule_missing_rank(self):
|
||||
actions = {
|
||||
0: [_Action(0, F, 0), _Action(0, B, 0)],
|
||||
}
|
||||
pp_group_size = 2
|
||||
num_stages = 2
|
||||
num_microbatches = 1
|
||||
with self.assertRaises(AssertionError):
|
||||
_validate_schedule(actions, pp_group_size, num_stages, num_microbatches)
|
||||
|
||||
def test_invalid_schedule_missing_action(self):
|
||||
actions = {
|
||||
0: [_Action(0, F, 0)],
|
||||
1: [_Action(1, F, 0)],
|
||||
}
|
||||
pp_group_size = 2
|
||||
num_stages = 2
|
||||
num_microbatches = 1
|
||||
with self.assertRaises(AssertionError):
|
||||
_validate_schedule(actions, pp_group_size, num_stages, num_microbatches)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestScheduleLowering)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -138,31 +138,38 @@ class _Action(NamedTuple):
|
|||
return repr
|
||||
|
||||
@staticmethod
|
||||
def from_str(str):
|
||||
def from_str(action_string: str):
|
||||
"""
|
||||
Reverse of __repr__
|
||||
|
||||
String should be formatted as [stage][action type][(microbatch)]
|
||||
e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
|
||||
"""
|
||||
if match := _action_regex.match(str):
|
||||
action_string = action_string.strip()
|
||||
if match := _action_regex.match(action_string):
|
||||
stage_index, computation_type, microbatch_index = match.groups()
|
||||
return _Action(
|
||||
int(stage_index),
|
||||
_ComputationType.from_str(computation_type),
|
||||
int(microbatch_index) if len(microbatch_index) else None,
|
||||
)
|
||||
elif str == "" or str.isspace():
|
||||
elif action_string == "":
|
||||
return None
|
||||
raise RuntimeError(
|
||||
f"Invalid action string: {str}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
|
||||
f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
|
||||
)
|
||||
|
||||
|
||||
def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -> str:
|
||||
def _format_pipeline_order(
|
||||
pipeline_order: Dict[int, List[Optional[_Action]]],
|
||||
error_step_number: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
|
||||
and returns the formatted string
|
||||
and returns the formatted string.
|
||||
|
||||
If `error_step_number` is passed in, an additional label will be added to signify which step
|
||||
that it is erroring on.
|
||||
"""
|
||||
|
||||
# don't mutate the original
|
||||
|
|
@ -202,6 +209,12 @@ def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) -
|
|||
formatted_rows = [
|
||||
f"{label}: "
|
||||
+ " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
|
||||
+ (
|
||||
" <-- ERROR HERE"
|
||||
if error_step_number is not None
|
||||
and int(label.split()[1]) == error_step_number
|
||||
else ""
|
||||
)
|
||||
for label, row in zip(step_labels, transposed_actions)
|
||||
]
|
||||
# Join the rows into a single string
|
||||
|
|
@ -987,6 +1000,73 @@ def _add_send_recv(
|
|||
return comm_actions
|
||||
|
||||
|
||||
def _validate_schedule(
|
||||
actions: Dict[int, List[Optional[_Action]]],
|
||||
pp_group_size: int,
|
||||
num_stages: int,
|
||||
num_microbatches: int,
|
||||
):
|
||||
assert (
|
||||
len(actions) == pp_group_size
|
||||
), f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
|
||||
for rank in range(pp_group_size):
|
||||
assert rank in actions, f"Schedule is missing actions for rank {rank}"
|
||||
|
||||
# We will count all the actions per stage and ensure they happen in a valid order
|
||||
# (e.g. F before (B, I) before W for a given microbatch)
|
||||
stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
|
||||
stage_id: {
|
||||
F: set(),
|
||||
B: set(),
|
||||
I: set(),
|
||||
W: set(),
|
||||
}
|
||||
for stage_id in range(num_stages)
|
||||
}
|
||||
for rank in actions:
|
||||
for action in actions[rank]:
|
||||
if action is None:
|
||||
continue
|
||||
assert isinstance(
|
||||
action, _Action
|
||||
), f"Got an invalid action: {action}, expected instance of _Action"
|
||||
s_id = action.stage_index
|
||||
ctype = action.computation_type
|
||||
mb_id = action.microbatch_index
|
||||
if ctype == F:
|
||||
stage_actions[s_id][F].add(mb_id)
|
||||
elif ctype == B:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
stage_actions[s_id][B].add(mb_id)
|
||||
elif ctype == I:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
stage_actions[s_id][I].add(mb_id)
|
||||
elif ctype == W:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][B]
|
||||
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
|
||||
stage_actions[s_id][W].add(mb_id)
|
||||
|
||||
for s_id in stage_actions:
|
||||
f_mb = len(stage_actions[s_id][F])
|
||||
b_mb = len(stage_actions[s_id][B])
|
||||
i_mb = len(stage_actions[s_id][I])
|
||||
w_mb = len(stage_actions[s_id][W])
|
||||
|
||||
assert (
|
||||
f_mb == num_microbatches
|
||||
), f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
|
||||
|
||||
assert (
|
||||
b_mb + (i_mb + w_mb) // 2 == num_microbatches
|
||||
), f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
|
||||
but got B={b_mb}, I={i_mb}, W={w_mb}"
|
||||
|
||||
|
||||
class PipelineScheduleMulti(_PipelineSchedule):
|
||||
"""
|
||||
Base class for multi-stage schedules.
|
||||
|
|
@ -1066,72 +1146,6 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
for rank in self.pipeline_order:
|
||||
writer.writerow(self.pipeline_order[rank])
|
||||
|
||||
def _validate_schedule(self):
|
||||
# TODO(whc) this should be merged with the logic in test_schedule.py#L453-L554
|
||||
def _validate_rank_actions(
|
||||
actions: Dict[int, List[_Action | None]],
|
||||
num_stages: int,
|
||||
num_microbatches: int,
|
||||
):
|
||||
# We will count all the actions per stage and ensure they happen in a valid order
|
||||
# (e.g. F before (B, I) before W for a given microbatch)
|
||||
stage_actions: Dict[int, Dict[_ComputationType, Set]] = {
|
||||
stage_id: {
|
||||
F: set(),
|
||||
B: set(),
|
||||
W: set(),
|
||||
}
|
||||
for stage_id in range(num_stages)
|
||||
}
|
||||
for rank in actions:
|
||||
for action in actions[rank]:
|
||||
if action is None:
|
||||
continue
|
||||
assert isinstance(
|
||||
action, _Action
|
||||
), f"Got an invalid action: {action}, expected instance of _Action"
|
||||
s_id = action.stage_index
|
||||
ctype = action.computation_type
|
||||
mb_id = action.microbatch_index
|
||||
if ctype == F:
|
||||
stage_actions[s_id][F].add(mb_id)
|
||||
elif ctype == B:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
stage_actions[s_id][B].add(mb_id)
|
||||
elif ctype == I:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][F]
|
||||
), f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward"
|
||||
# TODO(whc) do we need to track I separately from B or should we just merge them for simplicity
|
||||
stage_actions[s_id][B].add(mb_id)
|
||||
elif ctype == W:
|
||||
assert (
|
||||
mb_id in stage_actions[s_id][B]
|
||||
), f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward"
|
||||
stage_actions[s_id][W].add(mb_id)
|
||||
|
||||
for s_id in stage_actions:
|
||||
for ctype in (F, B, W):
|
||||
stage_mb = len(stage_actions[s_id][ctype])
|
||||
assert (
|
||||
stage_mb == num_microbatches
|
||||
), f"Got {stage_mb} {ctype} microbatches for stage {s_id}, expected {num_microbatches}"
|
||||
|
||||
assert (
|
||||
len(self.pipeline_order) == self.pp_group_size
|
||||
), f"Schedule has incorrect number of ranks - expected {self.pp_group_size}, actual {len(self.pipeline_order)}"
|
||||
for rank in range(self.pp_group_size):
|
||||
assert (
|
||||
rank in self.pipeline_order
|
||||
), f"Schedule is missing actions for rank {rank}"
|
||||
_validate_rank_actions(
|
||||
self.pipeline_order,
|
||||
self._num_stages,
|
||||
self._n_microbatches,
|
||||
)
|
||||
|
||||
def _load_csv(self, filename, format="compute_only"):
|
||||
"""Load a CSV representation of the schedule from a file with the provided filename.
|
||||
This API will most likely get renamed/refactored so is marked as internal for now.
|
||||
|
|
@ -1143,7 +1157,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
reader = csv.reader(csvfile)
|
||||
for rank, row in enumerate(reader):
|
||||
self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
|
||||
self._validate_schedule()
|
||||
_validate_schedule(
|
||||
self.pipeline_order,
|
||||
self.pp_group_size,
|
||||
self._num_stages,
|
||||
self._n_microbatches,
|
||||
)
|
||||
|
||||
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
|
||||
"""
|
||||
|
|
@ -1342,7 +1361,12 @@ class PipelineScheduleMulti(_PipelineSchedule):
|
|||
time_step,
|
||||
action,
|
||||
)
|
||||
logger.error("%s", _format_pipeline_order(self.pipeline_order))
|
||||
logger.error(
|
||||
"%s",
|
||||
_format_pipeline_order(
|
||||
self.pipeline_order, error_step_number=time_step
|
||||
),
|
||||
)
|
||||
raise e
|
||||
# Return losses if there is a container passed in
|
||||
self._update_losses(self._stages, losses)
|
||||
|
|
@ -1649,7 +1673,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
|
|||
)
|
||||
# TODO(whc) what is the best practice for printing a multiline log?
|
||||
# logger will split it into multiple log lines, but this makes it hard to read (too wide)
|
||||
print(_format_pipeline_order(self.pipeline_order_with_comms)) # type: ignore[arg-type]
|
||||
print(_format_pipeline_order(self.pipeline_order_with_comms, error_step_number=time_step)) # type: ignore[arg-type]
|
||||
raise e
|
||||
|
||||
# Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
|
||||
|
|
|
|||
Loading…
Reference in a new issue