diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py
index 3dfdc3b1784..c7a3cbdf2f1 100644
--- a/torch/distributed/pipelining/stage.py
+++ b/torch/distributed/pipelining/stage.py
@@ -3,7 +3,7 @@
 import logging
 import operator
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -1211,163 +1211,6 @@ def build_stage(
     )
 
 
-# Manual PipelineStage functions and definition
-
-METADATA_TENSOR_LEN = 100
-PLACEHOLDER_VAL = -1
-
-
-def _create_empty_tensors(
-    tensor: Union[torch.Tensor, Iterable[torch.Tensor]], device: torch.device
-) -> List[torch.Tensor]:
-    """
-    Creates a list of empty tensors with the same properties (like shape and dtype) as the input tensor(s),
-    and places them on the specified device.
-    Args:
-        tensor (Union[torch.Tensor, List[torch.tensor]]): The input tensor(s).
-        device (torch.device): The device where the new tensors will be placed.
-    Returns:
-        List[torch.Tensor]: A list of empty tensors with the same properties as the input tensor(s).
-    """
-    if isinstance(tensor, torch.Tensor):
-        return [torch.empty_like(tensor, device=device)]
-    elif isinstance(tensor, (list, tuple)):
-        return [torch.empty_like(t, device=device) for t in tensor]
-    raise TypeError(f"Unsupported type {type(tensor)} cannot create empty tensors")
-
-
-def _create_metadata_tensor(
-    tensors: Optional[List[torch.Tensor]] = None,
-    device: Optional[torch.device] = torch.device("cpu"),
-) -> torch.Tensor:
-    """
-    Create a metadata tensor that can be sent over the wire.
-    This tensor contains the number of dimensions and the shape of each tensor being sent.
-
-    The data is of format [num_dims, dim1, dim2, ...].
-    If the tensor is None, a tensor of only placeholder values will be returned.
-
-    Inputs:
-        tensors: A list of tensors, the tensors will converted into its shape dimensions and
-            these dimensions will be concatenated.
-        device: The device where the metadata tensor will be created.
-            If the tensor is None, then this tensor will contain PLACEHOLDER_VALs.
-
-    """
-    metadata_tensor = torch.full(
-        (METADATA_TENSOR_LEN,),
-        PLACEHOLDER_VAL,
-        dtype=torch.int32,
-        device=device,
-    )
-    if tensors:
-        # Create a list of tensors containing the number of dimensions and the shape of each tensor
-        data = [
-            # data is of format [num_dims, dim1, dim2, ...]
-            torch.tensor(
-                [len(tensor.shape)] + list(tensor.shape),
-                dtype=torch.int32,
-                device=device,
-            )
-            for tensor in tensors
-        ]
-        # Concatenate the data into a single tensor
-        data_tensor = torch.cat(data)
-        dt_shape = data_tensor.shape[0]
-        if dt_shape > METADATA_TENSOR_LEN:
-            raise ValueError(
-                f"Metadata tensor size ({dt_shape}) exceeds maximum allowed length ({METADATA_TENSOR_LEN})."
-            )
-        metadata_tensor[:dt_shape] = data_tensor
-    return metadata_tensor
-
-
-def _extract_metadata_from_tensor(tensor: torch.Tensor) -> List[torch.Size]:
-    """
-    Extract the number of dimensions and the shape of each tensor from a metadata tensor.
-    """
-    metadata: List[torch.Size] = []
-    i = 0
-    while i < len(tensor) and tensor[i] != PLACEHOLDER_VAL:
-        num_dims = int(tensor[i].item())
-        shape = torch.Size(tensor[i + 1 : i + 1 + num_dims].tolist())
-        metadata.append(shape)
-        i += num_dims + 1
-    return metadata
-
-
-def _get_stage_shapes(
-    stage_modules: List[nn.Module],
-    stage_ids: List[int],
-    num_stages: int,
-    rank: int,
-    world_size: int,
-    device: torch.device,
-    microbatch: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
-):
-    """
-    Performs a dry run through all the pipeline stages (a rank can have multiple pipeline stages in the case of
-    virtual pipelining) and returns the shape of the inputs and outputs of the module.
-    Only the first stage must pass in a microbatch.
-
-    Each rank must call _get_stage_shapes or the program will hang.
-
-    Args:
-        stage_modules: The chunks assigned to this rank. Rhe length should be 1 for any
-            non-interleaved schedules and >1 for any interleaved schedules.
-        stage_ids: The id of the stages assigned to this rank.
-        num_stages: Total number of stages.
-        rank: Rank of the current process.
-        world_size: Number of processes participating in the pipeline.
-        device: Device where the tensors are allocated.
-
-    Returns a dictionary containing the following keys:
-        "inputs": Shape of the inputs to the module
-        "outputs": Shape of the outputs of the module
-    """
-
-    stage_id_to_shapes: Dict[int, Dict[str, list[torch.Size]]] = {}
-    for stage_id, model in zip(stage_ids, stage_modules):
-        input_shape_metadata_tensor = _create_metadata_tensor(device=device)
-        # TODO: Assumes prev_stage == rank - 1 and next_stage == rank + 1
-        prev_rank = (rank - 1) % world_size
-        next_rank = (rank + 1) % world_size
-        shapes = {}
-
-        # first stage doesn't receive anything and uses a microbatch
-        if stage_id == 0:
-            if microbatch is None:
-                raise RuntimeError("Microbatch is required for first stage")
-            example_fwd_inputs = microbatch
-            if isinstance(example_fwd_inputs, torch.Tensor):
-                example_fwd_inputs = [example_fwd_inputs]
-        else:
-            # other stages must receive shape information
-            # TODO: send/recv should take a group, rather than use the default group
-            dist.recv(input_shape_metadata_tensor, prev_rank)
-            metadata = _extract_metadata_from_tensor(input_shape_metadata_tensor)
-            example_fwd_inputs = [
-                torch.empty(shape_list, device=device) for shape_list in metadata
-            ]
-        shapes["inputs"] = [fwd_input.shape for fwd_input in example_fwd_inputs]
-
-        # perform forward
-        # TODO: if forward fails raise a more descriptive error explaining which stage failed
-        fwd_outputs = model(*example_fwd_inputs)
-        fwd_outputs = _create_empty_tensors(fwd_outputs, device)
-        shapes["outputs"] = [fwd_output.shape for fwd_output in fwd_outputs]
-
-        # send shape dims
-        if stage_id != num_stages - 1:
-            output_shape_metadata_tensor = _create_metadata_tensor(
-                fwd_outputs, device=device
-            )
-            dist.send(output_shape_metadata_tensor, next_rank)
-        stage_id_to_shapes[stage_id] = shapes
-    logger.info(stage_id_to_shapes)
-    return stage_id_to_shapes
-
-
 class PipelineStage(_PipelineStageBase):
     """
     A class representing a pipeline stage in a pipeline parallelism setup.
@@ -1659,100 +1502,3 @@ class PipelineStage(_PipelineStageBase):
             ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_rank, self.group))
 
         return True
-
-
-def _validate_stage_shapes(pipeline_stages: List[PipelineStage]):
-    """
-    Check that the buffer shapes match between stages was expected by performing an all_gather between
-    all stages.
-    """
-    if len(pipeline_stages) == 0:
-        raise ValueError("No pipeline stages provided.")
-
-    virtual_pipeline_size = len(pipeline_stages)
-    all_inputs = []
-    all_outputs = []
-    world_size = pipeline_stages[0].group_size
-    num_stages = pipeline_stages[0].num_stages
-
-    # perform all gathers between all stages
-    for virtual_id, stage in enumerate(pipeline_stages):
-        world_size = stage.group_size
-        stage_id: int = stage.stage_index
-        rank = stage.group_rank
-        # check that world_size and num_stages are consistent across all stages
-        if stage.group_size != world_size:
-            raise ValueError(
-                f"Stage id {stage_id} has world size ({stage.group_size}) \
-                which does not match world size ({world_size}) of other stages."
-            )
-        if stage.num_stages != num_stages:
-            raise ValueError(
-                f"Stage id {stage_id} has num stages ({stage.num_stages}) \
-                which does not match num stages ({num_stages}) of other stages."
-            )
-
-        pg_rank = dist.get_rank(stage.group)
-        if rank != pg_rank:
-            raise ValueError(
-                f"Rank {rank} is not equal to process group rank {pg_rank}"
-            )
-
-        if (num_stages := stage.num_stages) % world_size != 0:
-            raise ValueError(
-                f"Number of stages ({num_stages}) must be a multiple of the world_size ({world_size})"
-            )
-
-        # all gather each ranks inputs
-        tensor_list = [
-            _create_metadata_tensor(device=stage.device)
-            for _ in range(stage.group_size)
-        ]
-        expected_inputs = stage.inputs
-        stage_input = _create_metadata_tensor(expected_inputs, device=stage.device)
-        dist.all_gather(tensor_list, stage_input)
-        stage_input_shapes = [
-            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
-        ]
-
-        # all gather each ranks outputs
-        tensor_list = [
-            _create_metadata_tensor(device=stage.device)
-            for _ in range(stage.group_size)
-        ]
-        outputs_meta = stage.get_outputs_meta()
-        # TODO, (1) are we deleting output validation when we move to shape inference?
-        # (2) if not, we should support multiple outputs
-        assert (
-            len(outputs_meta) == 1
-        ), f"validation logic assumes single output, got {len(outputs_meta)} outputs "
-        dist.all_gather(tensor_list, outputs_meta[0])
-        stage_output_shapes = [
-            _extract_metadata_from_tensor(tensor) for tensor in tensor_list
-        ]
-
-        logger.debug(
-            f"Rank: {pg_rank}",  # noqa: G004
-            f"Stage id: {stage_id}",
-            f"Stage num stages: {stage.num_stages}",
-            f"Stage rank: {rank}",
-            f"Stage world size: {world_size}",
-            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}",  # noqa: G003
-            f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}",  # noqa: G003
-        )
-
-        all_inputs.extend(stage_input_shapes)
-        all_outputs.extend(stage_output_shapes)
-
-        # log only rank 0's view, they will all be equivalent
-        if pg_rank == 0:
-            logger.info(
-                "all stage inputs: %s \n all stage outputs: %s", all_inputs, all_outputs
-            )
-
-    # Check if the output for stage 0 matches the input at stage 1, and so forth
-    for i in range(virtual_pipeline_size * world_size - 1):
-        if (out := all_outputs[i]) != (inp := all_inputs[i + 1]):
-            raise ValueError(
-                f"Stage_id {i} output shape {out} at does not match stage_id {i + 1} input shape {inp}."
-            )
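For reference (not part of the patch): the removed helpers encode tensor shapes into a fixed-length int32 buffer of the form [num_dims, dim1, dim2, ...], padded with PLACEHOLDER_VAL, so that neighboring stages can exchange shape information with a single send/recv or all_gather. Below is a minimal standalone sketch of that round trip; the encode_shapes/decode_shapes names are illustrative only and are not part of torch.distributed.pipelining.

import torch

METADATA_TENSOR_LEN = 100  # fixed wire length, as in the removed helpers
PLACEHOLDER_VAL = -1  # padding value marking "no more shapes"


def encode_shapes(tensors, device=torch.device("cpu")):
    # Pack each tensor's rank followed by its dims: [num_dims, dim1, dim2, ...],
    # then pad the remainder of the buffer with PLACEHOLDER_VAL.
    meta = torch.full(
        (METADATA_TENSOR_LEN,), PLACEHOLDER_VAL, dtype=torch.int32, device=device
    )
    flat = []
    for t in tensors:
        flat.extend([t.dim(), *t.shape])
    if len(flat) > METADATA_TENSOR_LEN:
        raise ValueError(f"metadata length {len(flat)} exceeds {METADATA_TENSOR_LEN}")
    meta[: len(flat)] = torch.tensor(flat, dtype=torch.int32, device=device)
    return meta


def decode_shapes(meta):
    # Walk the buffer until the placeholder padding and rebuild each torch.Size.
    shapes, i = [], 0
    while i < len(meta) and meta[i] != PLACEHOLDER_VAL:
        num_dims = int(meta[i].item())
        shapes.append(torch.Size(meta[i + 1 : i + 1 + num_dims].tolist()))
        i += num_dims + 1
    return shapes


# Round trip: two activation tensors -> one metadata tensor -> shapes recovered.
meta = encode_shapes([torch.empty(8, 16), torch.empty(4, 3, 2)])
assert decode_shapes(meta) == [torch.Size([8, 16]), torch.Size([4, 3, 2])]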