mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[pipelining] clean up stage functions (#140418)
Clean up methods related to stage input/output shape verification which are no longer needed Pull Request resolved: https://github.com/pytorch/pytorch/pull/140418 Approved by: https://github.com/wconstab ghstack dependencies: #140019
This commit is contained in:
parent
2ac71a5771
commit
7578a0b268
1 changed file with 1 addition and 255 deletions
|
|
@ -3,7 +3,7 @@
|
|||
import logging
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -1211,163 +1211,6 @@ def build_stage(
|
|||
)
|
||||
|
||||
|
||||
# Manual PipelineStage functions and definition

# Fixed length of the int32 metadata tensor exchanged between ranks during
# shape verification; it must be large enough to hold the concatenated
# [num_dims, dim1, dim2, ...] records for every tensor being described.
METADATA_TENSOR_LEN = 100
# Sentinel written into unused trailing slots of a metadata tensor; decoding
# stops at the first occurrence of this value.
PLACEHOLDER_VAL = -1
|
||||
|
||||
|
||||
def _create_empty_tensors(
|
||||
tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
|
||||
and places them on the specified device.
|
||||
Args:
|
||||
tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
|
||||
device (torch.device): The device where the new tensors will be placed.
|
||||
Returns:
|
||||
List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return [torch.empty_like(tensor, device=device)]
|
||||
elif isinstance(tensor, (list, tuple)):
|
||||
return [torch.empty_like(t, device=device) for t in tensor]
|
||||
raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
|
||||
|
||||
|
||||
def _create_metadata_tensor(
    tensors: Optional[List[torch.Tensor]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
) -> torch.Tensor:
    """
    Build a fixed-length int32 shape-metadata tensor suitable for the wire.

    Each input tensor is encoded as the record [num_dims, dim1, dim2, ...];
    records are concatenated into a tensor of length METADATA_TENSOR_LEN,
    with any unused trailing slots left at the PLACEHOLDER_VAL sentinel.
    When ``tensors`` is None (or empty) the result is all placeholders.

    Args:
        tensors: the tensors whose shapes are encoded, or None.
        device: device on which the metadata tensor is created.

    Raises:
        ValueError: if the concatenated records exceed METADATA_TENSOR_LEN.
    """
    out = torch.full(
        (METADATA_TENSOR_LEN,),
        PLACEHOLDER_VAL,
        dtype=torch.int32,
        device=device,
    )
    if tensors:
        records = []
        for t in tensors:
            # One record per tensor: [num_dims, dim1, dim2, ...]
            records.append(
                torch.tensor(
                    [t.dim(), *t.shape],
                    dtype=torch.int32,
                    device=device,
                )
            )
        encoded = torch.cat(records)
        used = encoded.shape[0]
        if used > METADATA_TENSOR_LEN:
            raise ValueError(
                f"Metadata tensor size ({used}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
            )
        out[:used] = encoded
    return out
|
||||
|
||||
|
||||
def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
    """
    Decode the shapes packed by ``_create_metadata_tensor``.

    Walks the concatenated [num_dims, dim1, dim2, ...] records until the
    first PLACEHOLDER_VAL sentinel (or the end of the tensor) and returns
    one ``torch.Size`` per record.
    """
    shapes: List[torch.Size] = []
    cursor = 0
    total = len(tensor)
    while cursor < total:
        if tensor[cursor] == PLACEHOLDER_VAL:
            # Remainder of the buffer is unused padding.
            break
        ndims = int(tensor[cursor].item())
        start = cursor + 1
        shapes.append(torch.Size(tensor[start : start + ndims].tolist()))
        cursor = start + ndims
    return shapes
|
||||
|
||||
|
||||
def _get_stage_shapes(
    stage_modules: List[nn.Module],
    stage_ids: List[int],
    num_stages: int,
    rank: int,
    world_size: int,
    device: torch.device,
    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
):
    """
    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
    virtual pipelining) and returns the shape of the inputs and outputs of the module.
    Only the first stage must pass in a microbatch.

    Each rank must call _get_stage_shapes or the program will hang: stages
    exchange shape metadata point-to-point (recv from the previous rank,
    send to the next rank), so a missing participant stalls its neighbors.

    Args:
        stage_modules: The chunks assigned to this rank. The length should be 1 for any
                       non-interleaved schedules and >1 for any interleaved schedules.
        stage_ids: The id of the stages assigned to this rank.
        num_stages: Total number of stages.
        rank: Rank of the current process.
        world_size: Number of processes participating in the pipeline.
        device: Device where the tensors are allocated.
        microbatch: Example input(s) for stage 0 only; required there,
            ignored for all other stages.

    Returns a dictionary containing the following keys:
        "inputs": Shape of the inputs to the module
        "outputs": Shape of the outputs of the module
    """

    stage_id_to_shapes: Dict[int, Dict[str, List[torch.Size]]] = {}
    for stage_id, model in zip(stage_ids, stage_modules):
        # Buffer for receiving the previous stage's output-shape metadata.
        input_shape_metadata_tensor = _create_metadata_tensor(device=device)
        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
        prev_rank = (rank - 1) % world_size
        next_rank = (rank + 1) % world_size
        shapes = {}

        # first stage doesn't receive anything and uses a microbatch
        if stage_id == 0:
            if microbatch is None:
                raise RuntimeError("Microbatch is required for first stage")
            example_fwd_inputs = microbatch
            # Normalize a bare tensor to a one-element list so the forward
            # call below can always unpack with *example_fwd_inputs.
            if isinstance(example_fwd_inputs, torch.Tensor):
                example_fwd_inputs = [example_fwd_inputs]
        else:
            # other stages must receive shape information
            # TODO: send/recv should take a group, rather than use the default group
            dist.recv(input_shape_metadata_tensor, prev_rank)
            metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
            # Materialize uninitialized inputs with the received shapes for
            # the dry-run forward pass.
            example_fwd_inputs = [
                torch.empty(shape_list, device=device) for shape_list in metadata
            ]
        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]

        # perform forward
        # TODO: if forward fails raise a more descriptive error explaining which stage failed
        fwd_outputs = model(*example_fwd_inputs)
        fwd_outputs = _create_empty_tensors(fwd_outputs, device)
        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]

        # send shape dims
        # Last stage has no successor, so it sends nothing.
        if stage_id != num_stages - 1:
            output_shape_metadata_tensor = _create_metadata_tensor(
                fwd_outputs, device=device
            )
            dist.send(output_shape_metadata_tensor, next_rank)
        stage_id_to_shapes[stage_id] = shapes
    logger.info(stage_id_to_shapes)
    return stage_id_to_shapes
|
||||
|
||||
|
||||
class PipelineStage(_PipelineStageBase):
|
||||
"""
|
||||
A class representing a pipeline stage in a pipeline parallelism setup.
|
||||
|
|
@ -1659,100 +1502,3 @@ class PipelineStage(_PipelineStageBase):
|
|||
ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_rank, self.group))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
    """
    Check that buffer shapes match between adjacent stages by all-gathering
    every stage's input/output shape metadata across the pipeline group.

    Every rank in the pipeline must call this function (it issues
    collectives); otherwise the program will hang.

    Args:
        pipeline_stages: the stages owned by this rank. Length is 1 for
            non-interleaved schedules and >1 for virtual pipelining.

    Raises:
        ValueError: if no stages are given, if stage metadata (world size,
            num stages, rank) is inconsistent across stages, or if a stage's
            output shapes do not match the next stage's input shapes.
    """
    if len(pipeline_stages) == 0:
        raise ValueError("No pipeline stages provided.")

    virtual_pipeline_size = len(pipeline_stages)
    all_inputs = []
    all_outputs = []
    # Baselines taken from the first stage; every stage must agree with them.
    world_size = pipeline_stages[0].group_size
    num_stages = pipeline_stages[0].num_stages

    # perform all gathers between all stages
    for virtual_id, stage in enumerate(pipeline_stages):
        stage_id: int = stage.stage_index
        rank = stage.group_rank

        # Check that world_size and num_stages are consistent across all
        # stages. (Previously `world_size` was overwritten from this very
        # stage right before the comparison, which made the check a no-op.)
        if stage.group_size != world_size:
            raise ValueError(
                f"Stage id {stage_id} has world size ({stage.group_size}) "
                f"which does not match world size ({world_size}) of other stages."
            )
        if stage.num_stages != num_stages:
            raise ValueError(
                f"Stage id {stage_id} has num stages ({stage.num_stages}) "
                f"which does not match num stages ({num_stages}) of other stages."
            )

        pg_rank = dist.get_rank(stage.group)
        if rank != pg_rank:
            raise ValueError(
                f"Rank {rank} is not equal to process group rank {pg_rank}"
            )

        if num_stages % world_size != 0:
            raise ValueError(
                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
            )

        # all gather each ranks inputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        expected_inputs = stage.inputs
        stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
        dist.all_gather(tensor_list, stage_input)
        stage_input_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        # all gather each ranks outputs
        tensor_list = [
            _create_metadata_tensor(device=stage.device)
            for _ in range(stage.group_size)
        ]
        outputs_meta = stage.get_outputs_meta()
        # TODO, (1) are we deleting output validation when we move to shape inference?
        # (2) if not, we should support multiple outputs
        assert (
            len(outputs_meta) == 1
        ), f"validation logic assumes single output, got {len(outputs_meta)} outputs "
        # Encode the output shapes the same way the input shapes are encoded
        # above, so _extract_metadata_from_tensor below can decode them.
        # NOTE(review): the original all-gathered the raw meta tensor itself
        # into the fixed-length metadata buffers, which mismatches
        # all_gather's same-shape requirement and decodes to garbage —
        # confirm against get_outputs_meta's contract.
        stage_output = _create_metadata_tensor(
            list(outputs_meta), device=stage.device
        )
        dist.all_gather(tensor_list, stage_output)
        stage_output_shapes = [
            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
        ]

        # One lazy %-style call: the original passed several f-strings as
        # positional args, which logging treats as %-format arguments and
        # errors on at emit time.
        logger.debug(
            "Rank: %s Stage id: %s Stage num stages: %s Stage rank: %s "
            "Stage world size: %s Stage %s-%s input shapes: %s "
            "Stage %s-%s output shapes: %s",
            pg_rank,
            stage_id,
            stage.num_stages,
            rank,
            world_size,
            virtual_id * world_size,
            (virtual_id + 1) * world_size - 1,
            stage_input_shapes,
            virtual_id * world_size,
            (virtual_id + 1) * world_size - 1,
            stage_output_shapes,
        )

        all_inputs.extend(stage_input_shapes)
        all_outputs.extend(stage_output_shapes)

    # log only rank 0's view, they will all be equivalent
    if pg_rank == 0:
        logger.info(
            "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
        )

    # Check if the output for stage 0 matches the input at stage 1, and so forth
    for i in range(virtual_pipeline_size * world_size - 1):
        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
            raise ValueError(
                f"Stage_id {i} output shape {out} does not match stage_id {i + 1} input shape {inp}."
            )
|
||||
|
|
|
|||
Loading…
Reference in a new issue