# mypy: allow-untyped-defs
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional, Union

import torch
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Placement


try:
    from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
except ImportError:
    from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
        tree_leaves,
        tree_map_only,
        TreeSpec,
    )


# Common type aliases
ArgsType = tuple[object, ...]
KwargsType = dict[str, object]

PlacementList = list[Optional[Placement]]

# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output
# type should be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    This is used to propagate tensor metadata; it must run under fake mode.
    """
    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(
        arg.tensor_meta.shape,
        arg.tensor_meta.stride,
        dtype=arg.tensor_meta.dtype,
    )


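# Example (illustrative): given a DTensorSpec whose tensor_meta is
# TensorMeta(shape=torch.Size([4, 8]), stride=(8, 1), dtype=torch.float32),
# the helper above returns torch.empty_strided((4, 8), (8, 1), dtype=torch.float32),
# i.e. a tensor with the right metadata but uninitialized data (hence fake mode).

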
def _is_inplace_op(op: OpOverload):
    # Simple analysis of the function schema to determine if this is an inplace
    # variant. It might not be entirely correct, but it's good enough for now.
    return op._schema.name[-1] == "_"


def _is_out_variant_op(op: OpOverload):
    # Simple analysis of the function schema to determine if this is an out
    # variant. It might not be entirely correct, but it's good enough for now.
    return "out" in op._schema.overload_name


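# Example (illustrative, using standard ATen overloads): the checks above would
# classify ops as follows:
#
#   _is_inplace_op(torch.ops.aten.add_.Tensor)     # True:  schema name "aten::add_"
#   _is_inplace_op(torch.ops.aten.add.Tensor)      # False: schema name "aten::add"
#   _is_out_variant_op(torch.ops.aten.add.out)     # True:  overload name "out"
#   _is_out_variant_op(torch.ops.aten.add.Tensor)  # False: overload name "Tensor"

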
def _pretty_print_spec(spec: object) -> str:
    if spec is None:
        return "None"
    elif isinstance(spec, DTensorSpec):
        return "".join([str(p) for p in spec.placements])
    elif isinstance(spec, Sequence):
        return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
    else:
        raise RuntimeError(f"Unknown spec type to print: spec={spec}")


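# Example output (illustrative; assumes `mesh` is a DeviceMesh and Shard/Replicate
# come from torch.distributed.tensor.placement_types; `spec` is hypothetical):
#
#   _pretty_print_spec(DTensorSpec(mesh, (Shard(0), Replicate())))  # -> "S(0)R"
#   _pretty_print_spec([spec, None])                                # -> "(S(0)R, None)"
#
# The short forms come from each Placement's __str__.

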
@dataclass
class PlacementStrategy:
    """
    A placement strategy describes acceptable sharding placements of the output
    and the tensor arguments of an operation.

    Note: when the op return value is a single DTensor object, output_specs is
    a DTensorSpec; when the return value is a tuple of Optional[DTensor],
    output_specs is a tuple of Optional[DTensorSpec].
    """

    output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]
    input_specs: Optional[Sequence[DTensorSpec]] = None

    # Redistribute costs for this op placement strategy: we need a nested list
    # to record the cost for each operand of this operator, and each operand of
    # this operator might have multiple placement strategies.
    redistribute_cost: Optional[list[list[float]]] = None

    @cached_property
    def output_spec(self) -> DTensorSpec:
        """
        This function requires that the strategy have exactly one DTensorSpec as
        the output spec. If output_specs is a tuple, we throw an exception.
        """
        if isinstance(self.output_specs, DTensorSpec):
            return self.output_specs
        else:
            raise ValueError(
                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
            )

    def input_spec(self, index: int = 0) -> DTensorSpec:
        assert self.input_specs is not None, "input_specs of PlacementStrategy is None!"
        assert len(self.input_specs) > index, (
            f"Invalid index {index} for input_specs of length "
            f"{len(self.input_specs)}: {self.input_specs}"
        )
        return self.input_specs[index]

    def __str__(self) -> str:
        if self.input_specs is not None:
            input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
        else:
            input_specs_str = ""
        output_spec_str = _pretty_print_spec(self.output_specs)
        return f"{input_specs_str}{output_spec_str}"


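# Example (hedged sketch): a PlacementStrategy for a matmul-like op on a 1-D mesh,
# assuming `mesh` is a DeviceMesh and Shard/Replicate come from
# torch.distributed.tensor.placement_types (`lhs_spec`/`rhs_spec` are hypothetical):
#
#   lhs_spec = DTensorSpec(mesh, (Shard(0),))      # rows sharded on mesh dim 0
#   rhs_spec = DTensorSpec(mesh, (Replicate(),))   # fully replicated
#   strategy = PlacementStrategy(
#       output_specs=DTensorSpec(mesh, (Shard(0),)),
#       input_specs=[lhs_spec, rhs_spec],
#   )
#   print(strategy)  # "(S(0), R) -> S(0)"

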
class StrategyType:
    """
    Base class type for op strategy. We have two StrategyTypes:
    OpStrategy and TupleStrategy.
    """


class OpStrategy(StrategyType):
    """
    OpStrategy consists of a list of placement strategies associated with the op.
    """

    def __init__(self, strategies: list[PlacementStrategy]) -> None:
        super().__init__()
        self.strategies: list[PlacementStrategy] = strategies

    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        mesh_shape = self.mesh_shape
        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all placement strategies.
        """
        return max(strategy.output_spec.num_shards for strategy in self.strategies)

    @property
    def mesh_shape(self):
        output_spec = self.strategies[0].output_specs
        if isinstance(output_spec, DTensorSpec):
            return output_spec.mesh.shape
        else:
            assert isinstance(
                output_spec, tuple
            ), "found no DTensorSpec in the OpStrategy!"
            assert output_spec[0] is not None
            return output_spec[0].mesh.shape

    @property
    def ndim(self):
        return self.strategies[0].output_spec.ndim

    @property
    def shape(self):
        return self.strategies[0].output_spec.shape


class TupleStrategy(StrategyType):
    """
    TupleStrategy represents that the output strategy of this op is a tuple of
    strategies, i.e. if the output of this op is a tuple of tensors or a list of
    tensors with possibly different placement strategies, we should return a
    TupleStrategy that contains a tuple of OpStrategy, where each child represents
    the sharding strategy of "each element" of the tuple/list of tensors the op
    returns.

    NOTE: if the output of the op is a List[Tensor] and all elements share the
    same placement strategy, then we should return a single OpStrategy instead
    of a TupleStrategy.
    """

    def __init__(self, childs: Sequence[StrategyType]) -> None:
        super().__init__()
        self.childs: Sequence[StrategyType] = childs

    def __str__(self) -> str:
        child_strategies_str = ", ".join([str(strat) for strat in self.childs])
        return f"TupleStrategy({child_strategies_str})"


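# Example (hedged sketch): for an op that returns a tuple of two tensors whose
# elements may be sharded differently, sharding prop could build something like
# (`mesh` and the specs below are hypothetical; Shard/Replicate come from
# torch.distributed.tensor.placement_types):
#
#   first = OpStrategy([PlacementStrategy(DTensorSpec(mesh, (Shard(0),)))])
#   second = OpStrategy([PlacementStrategy(DTensorSpec(mesh, (Replicate(),)))])
#   tuple_strategy = TupleStrategy(childs=[first, second])

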
@dataclass
class RuntimeSchemaInfo:
    """
    RuntimeSchemaInfo stores the operator schema related information for runtime
    (eager) execution. This is mainly used in two ways: 1. to generate a hash for
    the args to determine whether to re-run sharding prop or not; 2. to determine
    if we need pytree.
    """

    # This static_argnum records the static arg "starting index" for ops that have
    # non-tensor args/kwargs which would affect sharding propagation results. All
    # args starting from this index would be hashed into our sharding cache.
    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
    static_argnum: int = 100
    # This static_kwargkey records static kwarg names which would affect sharding prop.
    static_kwargkey: Optional[list[str]] = None
    # Each op can decide if it wants to use pytree flatten/unflatten during operator
    # eager execution. By default we don't do flatten/unflatten; we only do so if
    # the op indicates it needs to. This is to accelerate eager performance.
    needs_pytree: bool = False


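# Example (hedged sketch): for a shape op like aten.view, the non-tensor shape
# argument (positional index 1) changes the sharding prop result, so it should be
# part of the cache key; an op whose result depends on a `dim` kwarg could list it
# via static_kwargkey (both instances below are illustrative):
#
#   view_schema_info = RuntimeSchemaInfo(static_argnum=1)
#   kwarg_schema_info = RuntimeSchemaInfo(static_argnum=1, static_kwargkey=["dim"])

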
@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schema. It includes
    DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
    preserved). It is mainly used by DTensor's dispatching logic to perform various
    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.).

    NOTE: this should be used as a read-only data class.
    TODO: make this a frozen dataclass.

    Args:
        op: the operator overload we are intercepting
        args_schema: contains args, except that the DTensor args have been replaced
            with their DTensorSpec or OpStrategy
        kwargs_schema: contains kwargs, except that the DTensor kwargs have been
            replaced with their DTensorSpec or OpStrategy
    """

    op: OpOverload
    args_schema: ArgsType
    kwargs_schema: KwargsType

    schema_info: Optional[RuntimeSchemaInfo] = None

    @property
    def args_spec(self) -> tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: contains a clean tuple of args spec
        with NO non-DTensor positional arguments (i.e. int/float/tuple, etc.),
        mainly used by sharding propagation to propagate the output spec.
        """
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, DTensorSpec))

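    # Example (illustrative): for aten.add.Tensor(dtensor_a, dtensor_b), the
    # args_schema would be (spec_a, spec_b) and args_spec returns exactly those
    # DTensorSpecs; for aten.view(dtensor, [4, 8]) it returns just (spec,),
    # dropping the non-tensor shape argument.
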
    @property
    def args_strategy(self) -> tuple[OpStrategy, ...]:
        # Filter out non-relevant values from the args schema to get a clean
        # OpStrategy list; kept separate from args_spec for ease of type annotation.
        # TODO: see if we should merge this with args_spec
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, OpStrategy))

    def __repr__(self) -> str:
        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
        return (
            f"OpSchema(op={self.op},"
            f" args_schema=({args_schema}),"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __str__(self) -> str:
        args_schema: list[str] = []
        mesh_shape = None
        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
                mesh_shape = arg.mesh.shape
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
                mesh_shape = arg.mesh_shape
            elif isinstance(arg, TupleStrategy):
                first_op_strtgy = arg.childs[0]
                assert isinstance(first_op_strtgy, OpStrategy)
                mesh_shape = first_op_strtgy.mesh_shape
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))
        return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})"

    def __post_init__(self) -> None:
        has_symints = False
        for a in self.args_schema:
            if isinstance(a, DTensorSpec) and a.tensor_meta is not None:
                if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape):
                    has_symints = True
                    break
        self.has_symints = has_symints

    def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool:
        arg = self.args_schema[arg_idx]
        is_tensor = isinstance(arg, DTensorSpec)
        if is_tensor:
            return True

        if not isinstance(arg, list):
            return False

        return all(isinstance(e, DTensorSpec) or e is None for e in arg)

    def return_type_tuple_tensor_like(self) -> bool:
        # All dispatch ops could only return Tuple[Tensor] or have None/ints/floats
        # in the tuple, but the first element must be a Tensor, so this check is enough.
        return_types = self.op._schema.returns
        return len(return_types) > 1 and isinstance(
            return_types[0].type, torch.TensorType
        )

    def return_type_tensor(self) -> bool:
        return_types = self.op._schema.returns
        # All dispatch ops only return Tensor or Tuple[Tensor] for tensor-like
        # return types, so this check is enough for tensor-like types.
        return isinstance(return_types[0].type, torch.TensorType)

    def __hash__(self) -> int:
        # Only hash the args and kwargs that the op indicates should be hashed.
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        args_to_hash = tuple(
            tuple(e) if isinstance(e, list) else e
            for i, e in enumerate(self.args_schema)
            if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum
        )
        if static_kwargkey is not None:
            kwargs_to_hash = tuple(
                self.kwargs_schema.get(k, None) for k in static_kwargkey
            )
            return hash((self.op, args_to_hash, kwargs_to_hash))
        else:
            return hash((self.op, args_to_hash))

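    # Example (illustrative): with RuntimeSchemaInfo(static_argnum=1) on a
    # view-like op, the hash covers (op, (input_spec, shape)), so calling the op
    # again with the same input spec but a different shape misses the sharding
    # cache and re-runs propagation, as intended.
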
    def __eq__(self, other: object) -> bool:
        # early return checks
        if not isinstance(other, OpSchema):
            return False

        if self.op != other.op:
            return False

        if len(self.args_schema) != len(other.args_schema):
            return False

        # compare each element and early return if any of them is different
        if not self.schema_info:
            static_argnum = len(self.args_schema)
            static_kwargkey = None
        else:
            static_argnum = self.schema_info.static_argnum
            static_kwargkey = self.schema_info.static_kwargkey

        for i, (self_arg, other_arg) in enumerate(
            zip(self.args_schema, other.args_schema)
        ):
            if isinstance(self_arg, DTensorSpec) and self_arg != other_arg:
                return False
            elif i >= static_argnum and self_arg != other_arg:
                return False

        # check kwarg equality when there's a static kwarg key
        if static_kwargkey:
            for key in static_kwargkey:
                if self.kwargs_schema.get(key, None) != other.kwargs_schema.get(
                    key, None
                ):
                    return False

        return True

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator. This is mainly used
        by sharding propagation rules to run the local tensor operator on fake
        args and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator. This is mainly
        used by sharding propagation rules to run the local tensor operator on
        fake kwargs and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
        )

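    # Example (hedged sketch): sharding prop can run the local operator on fake
    # tensors rebuilt from the DTensorSpecs to infer output metadata
    # (`op_schema` is a hypothetical OpSchema instance; FakeTensorMode lives in
    # torch._subclasses.fake_tensor):
    #
    #   with FakeTensorMode():
    #       fake_out = op_schema.op(
    #           *op_schema.gen_fake_args(), **op_schema.gen_fake_kwargs()
    #       )
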
    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        suggestion_args_spec = self.args_spec
        new_arg_schema: list[object] = []
        idx_of_args_spec = 0
        if (
            origin_schema.schema_info is not None
            and origin_schema.schema_info.needs_pytree
        ):
            args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
        else:
            args_schema = origin_schema.args_schema
        for arg in args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class that is used by the sharding propagation; it
    could set the output_spec upon successful propagation. If needs_redistribute
    is set to True, a redistribute_schema would be returned together to indicate
    that the input arguments need to be redistributed before the op execution.

    NOTE: the redistribute_schema generated by sharding propagation should be
    exactly the same as the operator OpSchema, except for the DTensorSpecs.
    """

    output_spec: OutputSpecType
    redistribute_schema: Optional[OpSchema] = None
    needs_redistribute: bool = False


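# Example (hedged sketch): a propagation rule that requires replicated inputs
# might return something like this (`mesh` and `suggested_schema` below are
# hypothetical; Replicate comes from torch.distributed.tensor.placement_types):
#
#   OutputSharding(
#       output_spec=DTensorSpec(mesh, (Replicate(),)),
#       redistribute_schema=suggested_schema,  # same OpSchema, new DTensorSpecs
#       needs_redistribute=True,
#   )

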
@dataclass
class OpInfo:
    """
    All runtime op execution info is packed here.
    """

    mesh: DeviceMesh
    schema: OpSchema
    flat_args_schema: list[object]
    local_args: Sequence[object]
    local_kwargs: dict[str, object]
    args_tree_spec: Optional[TreeSpec] = None

    # the output sharding info
    output_sharding: Optional[OutputSharding] = None