mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
This reverts commit e84d1121ad.
Reverted https://github.com/pytorch/pytorch/pull/127690 on behalf of https://github.com/ZainRizvi due to Sorry but this is breaking internally. More details in D65483292 ([comment](https://github.com/pytorch/pytorch/pull/127690#issuecomment-2458381056))
67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
# mypy: allow-untyped-defs
|
|
import warnings
|
|
from typing import Tuple, Union
|
|
|
|
from torch.distributed.device_mesh import _mesh_resources
|
|
from torch.distributed.tensor import DeviceMesh
|
|
from torch.distributed.tensor.placement_types import Placement
|
|
|
|
|
|
# Prefer the real dynamo helper so callers can detect when torch.compile is
# tracing (used below to skip warnings.warn, which does not yet work under
# dynamo).
try:
    from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling
except Exception:
    # torch._dynamo may be unavailable (e.g. stripped-down builds); fall back
    # to a stub that always reports "not compiling".
    def is_torchdynamo_compiling():  # type: ignore[misc]
        return False
|
|
|
|
|
|
LayoutsType = Union[Placement, Tuple[Placement, ...]]
|
|
|
|
|
|
def _deprecate_warnings(func_name: str, extra_msg: str) -> None:
|
|
"""
|
|
Inject common validation logics for `_prepare_input` funcs via this decorator.
|
|
|
|
Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor`
|
|
and only 1D :class:`DeviceMesh` is passed in.
|
|
"""
|
|
# TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo.
|
|
if not is_torchdynamo_compiling():
|
|
warnings.warn(
|
|
f"{func_name} is deprecated and will be removed soon. {extra_msg}",
|
|
FutureWarning,
|
|
stacklevel=3,
|
|
)
|
|
|
|
|
|
def _validate_tp_mesh_dim(
|
|
device_mesh: DeviceMesh,
|
|
) -> None:
|
|
"""
|
|
Check whether TP mesh dimension is valid or not.
|
|
|
|
Args:
|
|
device_mesh (:class:`DeviceMesh`):
|
|
The `device_mesh` where we perform
|
|
Tensor Parallelism on.
|
|
|
|
Return:
|
|
`True` if the mesh dimension
|
|
is valid, `False` otherwise.
|
|
"""
|
|
if device_mesh.ndim > 1:
|
|
raise ValueError(
|
|
f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!"
|
|
'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]'
|
|
)
|
|
|
|
root_mesh = _mesh_resources.get_root_mesh(device_mesh)
|
|
# if a root mesh is not the same as device_mesh,
|
|
# meaning the device_mesh is sliced out from the root mesh.
|
|
if root_mesh and root_mesh != device_mesh:
|
|
tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh)
|
|
if tp_mesh_dim_in_root != root_mesh.ndim - 1:
|
|
raise RuntimeError(
|
|
f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.",
|
|
"Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.",
|
|
)
|