[pipelining] clean up stage functions (#140418)

Clean up methods related to stage input/output shape verification which are no longer needed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140418
Approved by: https://github.com/wconstab
ghstack dependencies: #140019
This commit is contained in:
Howard Huang 2024-11-12 08:53:15 -08:00 committed by PyTorch MergeBot
parent 2ac71a5771
commit 7578a0b268

View file

@ -3,7 +3,7 @@
import logging
import operator
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@ -1211,163 +1211,6 @@ def build_stage(
)
# Manual PipelineStage functions and definition
# Fixed length of the int32 metadata tensor exchanged between ranks; encoded
# shape data is packed at the front and unused slots keep the placeholder.
METADATA_TENSOR_LEN = 100
# Sentinel value marking unused slots in a metadata tensor (also fills the
# whole tensor when no shape data is provided).
PLACEHOLDER_VAL = -1
def _create_empty_tensors(
tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
) -> List[torch.Tensor]:
"""
Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
and places them on the specified device.
Args:
tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
device (torch.device): The device where the new tensors will be placed.
Returns:
List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
"""
if isinstance(tensor, torch.Tensor):
return [torch.empty_like(tensor, device=device)]
elif isinstance(tensor, (list, tuple)):
return [torch.empty_like(t, device=device) for t in tensor]
raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
def _create_metadata_tensor(
    tensors: Optional[List[torch.Tensor]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
) -> torch.Tensor:
    """
    Build a fixed-size int32 tensor encoding the shapes of ``tensors`` so it
    can be sent over the wire.

    Layout: one record per tensor of the form [num_dims, dim1, dim2, ...],
    packed back to back from the front; every remaining slot holds
    PLACEHOLDER_VAL. When ``tensors`` is None (or empty) the whole tensor is
    placeholder values.

    Args:
        tensors: tensors whose shapes are encoded, in order.
        device: device on which the metadata tensor is created.

    Raises:
        ValueError: if the packed shape data exceeds METADATA_TENSOR_LEN.
    """
    meta = torch.full(
        (METADATA_TENSOR_LEN,),
        PLACEHOLDER_VAL,
        dtype=torch.int32,
        device=device,
    )
    if not tensors:
        return meta
    # One [num_dims, dim1, dim2, ...] record per tensor, packed contiguously.
    records = []
    for t in tensors:
        records.append(
            torch.tensor([t.dim(), *t.shape], dtype=torch.int32, device=device)
        )
    packed = torch.cat(records)
    used = packed.shape[0]
    if used > METADATA_TENSOR_LEN:
        raise ValueError(
            f"Metadata tensor size ({used}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
        )
    meta[:used] = packed
    return meta
def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
    """
    Decode the shapes packed into a metadata tensor.

    Reads [num_dims, dim1, dim2, ...] records from the front of ``tensor``
    until a PLACEHOLDER_VAL (or the end of the tensor) is reached, returning
    one ``torch.Size`` per encoded tensor.
    """
    shapes: List[torch.Size] = []
    pos = 0
    total = len(tensor)
    while pos < total and tensor[pos] != PLACEHOLDER_VAL:
        ndims = int(tensor[pos].item())
        start = pos + 1
        shapes.append(torch.Size(tensor[start : start + ndims].tolist()))
        # Skip past this record: the dim-count slot plus its dims.
        pos = start + ndims
    return shapes
def _get_stage_shapes(
    stage_modules: List[nn.Module],
    stage_ids: List[int],
    num_stages: int,
    rank: int,
    world_size: int,
    device: torch.device,
    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
):
    """
    Dry-run every pipeline stage owned by this rank (a rank owns several
    stages under virtual pipelining) to discover each stage's input and
    output shapes.

    Only the first stage needs a ``microbatch``; every later stage receives
    its input shapes from the previous rank, so every rank must call
    _get_stage_shapes or the program will hang.

    Args:
        stage_modules: the chunks assigned to this rank. The length should be
            1 for non-interleaved schedules and >1 for interleaved ones.
        stage_ids: the ids of the stages assigned to this rank.
        num_stages: total number of stages.
        rank: rank of the current process.
        world_size: number of processes participating in the pipeline.
        device: device where the tensors are allocated.
        microbatch: example input, required only on the first stage.

    Returns:
        Dict mapping stage id to {"inputs": [...shapes...],
        "outputs": [...shapes...]}.
    """
    results: Dict[int, Dict[str, list[torch.Size]]] = {}
    for stage_id, module in zip(stage_ids, stage_modules):
        recv_buf = _create_metadata_tensor(device=device)
        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
        prev_rank = (rank - 1) % world_size
        next_rank = (rank + 1) % world_size
        if stage_id == 0:
            # The first stage has nobody to receive from; it takes its
            # example inputs directly from the provided microbatch.
            if microbatch is None:
                raise RuntimeError("Microbatch is required for first stage")
            example_inputs = (
                [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
            )
        else:
            # Every other stage learns its input shapes from upstream.
            # TODO: send/recv should take a group, rather than use the default group
            dist.recv(recv_buf, prev_rank)
            example_inputs = [
                torch.empty(shape, device=device)
                for shape in _extract_metadata_from_tensor(recv_buf)
            ]
        entry = {"inputs": [inp.shape for inp in example_inputs]}
        # perform forward
        # TODO: if forward fails raise a more descriptive error explaining which stage failed
        outputs = _create_empty_tensors(module(*example_inputs), device)
        entry["outputs"] = [out.shape for out in outputs]
        # Forward the output shapes downstream unless this is the last stage.
        if stage_id != num_stages - 1:
            dist.send(_create_metadata_tensor(outputs, device=device), next_rank)
        results[stage_id] = entry
    logger.info(results)
    return results
class PipelineStage(_PipelineStageBase):
"""
A class representing a pipeline stage in a pipeline parallelism setup.
@ -1659,100 +1502,3 @@ class PipelineStage(_PipelineStageBase):
ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_rank, self.group))
return True
def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
    """
    Check that send/recv buffer shapes match between adjacent pipeline stages
    by all-gathering every stage's input and output shapes and comparing
    stage i's outputs against stage i+1's inputs.

    This performs collective communication (all_gather), so every rank in the
    pipeline group must call it or the program will hang.

    Args:
        pipeline_stages: the stages owned by this rank (more than one only
            for interleaved / virtual-pipeline schedules).

    Raises:
        ValueError: if no stages are given, stage bookkeeping (world size,
            num stages, process-group rank) is inconsistent, or an output
            shape does not match the next stage's input shape.
    """
    if len(pipeline_stages) == 0:
        raise ValueError("No pipeline stages provided.")

    virtual_pipeline_size = len(pipeline_stages)
    all_inputs = []
    all_outputs = []
    # Reference values taken from the first stage; every stage must agree.
    # (Previously world_size was reassigned from the current stage right
    # before the comparison, which made the consistency check vacuous.)
    world_size = pipeline_stages[0].group_size
    num_stages = pipeline_stages[0].num_stages

    # perform all gathers between all stages
    for virtual_id, stage in enumerate(pipeline_stages):
        stage_id: int = stage.stage_index
        rank = stage.group_rank
        # check that world_size and num_stages are consistent across all stages
        if stage.group_size != world_size:
            raise ValueError(
                f"Stage id {stage_id} has world size ({stage.group_size}) "
                f"which does not match world size ({world_size}) of other stages."
            )
        if stage.num_stages != num_stages:
            raise ValueError(
                f"Stage id {stage_id} has num stages ({stage.num_stages}) "
                f"which does not match num stages ({num_stages}) of other stages."
            )
        pg_rank = dist.get_rank(stage.group)
        if rank != pg_rank:
            raise ValueError(
                f"Rank {rank} is not equal to process group rank {pg_rank}"
            )
        if num_stages % world_size != 0:
            raise ValueError(
                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
            )

        # all gather each ranks inputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        expected_inputs = stage.inputs
        stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
        dist.all_gather(tensor_list, stage_input)
        stage_input_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        # all gather each ranks outputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        outputs_meta = stage.get_outputs_meta()
        # Encode the output shapes into a fixed-size metadata tensor before
        # gathering. Gathering a raw output meta tensor would mismatch the
        # fixed-length receive buffers above and could not be decoded by
        # _extract_metadata_from_tensor; encoding also supports stages with
        # multiple outputs, which the old single-output assert rejected.
        stage_output = _create_metadata_tensor(
            list(outputs_meta), device=stage.device
        )
        dist.all_gather(tensor_list, stage_output)
        stage_output_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        # Single lazily-formatted message; the previous call passed several
        # positional message strings, which logging treats as %-format args
        # and silently drops.
        logger.debug(
            "Rank: %s, Stage id: %s, Stage num stages: %s, Stage rank: %s, "
            "Stage world size: %s, Stage %s-%s input shapes: %s, "
            "Stage %s-%s output shapes: %s",
            pg_rank,
            stage_id,
            stage.num_stages,
            rank,
            world_size,
            virtual_id * world_size,
            (virtual_id + 1) * world_size - 1,
            stage_input_shapes,
            virtual_id * world_size,
            (virtual_id + 1) * world_size - 1,
            stage_output_shapes,
        )

        all_inputs.extend(stage_input_shapes)
        all_outputs.extend(stage_output_shapes)

    # log only rank 0's view, they will all be equivalent
    if pg_rank == 0:
        logger.info(
            "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
        )

    # Check if the output for stage 0 matches the input at stage 1, and so forth
    for i in range(virtual_pipeline_size * world_size - 1):
        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
            raise ValueError(
                f"Stage_id {i} output shape {out} does not match stage_id {i + 1} input shape {inp}."
            )