diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py index a2b4b3a9b1c..000d2e86f61 100644 --- a/test/distributed/_spmd/test_tracing.py +++ b/test/distributed/_spmd/test_tracing.py @@ -2,17 +2,28 @@ from copy import deepcopy from functools import wraps -from typing import List +from typing import Any, Dict, List import numpy as np import torch +import torch.distributed as dist +import torch.fx as fx import torch.nn as nn -from torch.distributed._spmd.api import Schema, SPMD +from torch.distributed._spmd.api import ( + Override, + Schema, + SPMD, + compile, +) from torch.distributed._spmd.comm_tensor import CommTensor from torch.distributed._tensor import DeviceMesh, Replicate +from torch.distributed._tensor.ops.utils import register_prop_rule +from torch.distributed._tensor.op_schema import OpSchema, OutputSharding +from torch.distributed._tensor.placement_types import DTensorSpec from torch.distributed.distributed_c10d import get_global_rank, get_world_size from torch.fx.experimental.proxy_tensor import make_fx from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, @@ -369,5 +380,234 @@ class TraceModuleTest(DTensorTestBase): ) +class DataDependentModule(nn.Module): + def __init__(self, world_size): + super().__init__() + self.world_size = world_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise RuntimeError( + "This eager implementation shouldn't be executed." + "This implementation is just an example of how to get around " + "data-dependant user-defined modules. 
" + ) + shape = x.shape + x = x.view(-1) + positive = x[x >= 0] + negative = x[x < 0] + + in_sizes = torch.tensor( + [positive.numel(), negative.numel()], dtype=torch.int32 + ) + out_sizes = torch.empty_like(in_sizes) + dist.all_to_all_single( + out_sizes, + in_sizes, + output_split_sizes=[1, 1], + input_split_sizes=[1, 1], + ) + + xs = [positive, negative] + ys = [ + torch.Tensor(out_sizes[i].item()) for i in range(out_sizes.numel()) + ] + dist.all_to_all(ys, xs) + + # some dummy compute + for y in ys: + y.add_(1) + + dist.all_to_all(xs, ys) + + return torch.cat(xs).reshape(shape) + + +class DummyModel(nn.Module): + def __init__(self, world_size): + super().__init__() + self.l1 = nn.Linear(10, 10) + self.ddm = DataDependentModule(world_size) + self.l2 = nn.Linear(10, 10) + self.relu = nn.ReLU() + + def forward(self, x): + assert len(x.size()) == 2 + + return self.relu(self.l2(self.ddm(self.l1(x)))) + + +def ddm(x: torch.Tensor) -> torch.Tensor: + return x + + +def ddm_backward(grad: torch.Tensor) -> torch.Tensor: + return grad + + +dummy_lib = torch.library.Library("dummy", "DEF") +dummy_lib.define("ddm(Tensor x) -> Tensor") +dummy_lib.impl("ddm", ddm, "CompositeExplicitAutograd") +dummy_lib.define("ddm_backward(Tensor x) -> Tensor") +dummy_lib.impl("ddm_backward", ddm_backward, "CompositeExplicitAutograd") + + +def _identity_prop_rule(op_schema: OpSchema) -> OutputSharding: + (x,) = op_schema.args_schema + assert isinstance(x, DTensorSpec), f"expecting DTensorSpec but got {x}" + + return OutputSharding(output_spec=DTensorSpec(x.mesh, x.placements)) + + +@register_prop_rule(torch.ops.dummy.ddm.default) +def _prop_ddm(op_schema: OpSchema) -> OutputSharding: + return _identity_prop_rule(op_schema) + + +@register_prop_rule(torch.ops.dummy.ddm_backward.default) +def _prop_ddm_backward(op_schema: OpSchema) -> OutputSharding: + return _identity_prop_rule(op_schema) + + +class DDMFunction(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, x: torch.Tensor) 
-> torch.Tensor: + return torch.ops.dummy.ddm(x) + + @staticmethod + def backward(ctx: Any, grad_x: torch.Tensor) -> torch.Tensor: + return torch.ops.dummy.ddm_backward(grad_x) + + +class DummyDDM(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return DDMFunction.apply(x) + + +class TraceTrainStepTest(DTensorTestBase): + @property + def world_size(self): + return 2 + + @skip_if_lt_x_gpu(2) + @with_comms + def test_train_step_simple(self): + @compile() + def train_step(mod, inp): + mod(inp).sum().backward() + return [p.grad for p in mod.parameters()] + + rank = torch.distributed.get_rank() + inp = torch.randn(2, 10).cuda(rank) + # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize + # module parameters. + torch.manual_seed(0) + mod = nn.Linear(10, 10).cuda(rank) + + ddp_mod = DDP(deepcopy(mod), device_ids=[rank]) + ddp_inp = deepcopy(inp) + + grads = train_step(mod, inp) + ddp_mod(ddp_inp).sum().backward() + + for g1, p2 in zip(grads, ddp_mod.parameters()): + # FIXME(@mrshenli): DDP by default divides gradients by world size. + # Should we match that behavior? 
+ self.assertEqual(g1 / self.world_size, p2.grad) + + @skip_if_lt_x_gpu(2) + @with_comms + def test_sgd(self): + @compile() + def train_step(mod, opt, inp): + mod(inp).sum().backward() + opt.step() + + rank = torch.distributed.get_rank() + mod = nn.Linear(10, 10).cuda(rank) + # FIXME(@mrshenli): we have to enable foreach to get better perf + opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False) + inp = torch.zeros(2, 10).cuda(rank) + + mod(inp).sum().backward() + opt.step() + + # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce + train_step(mod, opt, inp) + + @skip_if_lt_x_gpu(2) + @with_comms + def test_adam(self): + @compile() + def train_step(mod, opt, inp): + mod(inp).sum().backward() + opt.step() + + rank = torch.distributed.get_rank() + mod = nn.Linear(10, 10).cuda(rank) + opt = torch.optim.Adam( + mod.parameters(), lr=0.01, foreach=False, capturable=True + ) + inp = torch.zeros(2, 10).cuda(rank) + + mod(inp).sum().backward() + opt.step() + + # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce + train_step(mod, opt, inp) + + @skip_if_lt_x_gpu(2) + @with_comms + def test_train_step_override(self): + transform_targets = [] + + class DDMOverride(Override): + def replacement( + self, orig_submodule: torch.nn.Module + ) -> torch.nn.Module: + return DummyDDM() + + def transform( + self, gm: fx.GraphModule, schema_map: Dict[str, Schema] + ) -> fx.Graph: + nonlocal transform_targets + for node in gm.graph.nodes: + if node.target in [ + torch.ops.dummy.ddm.default, + torch.ops.dummy.ddm_backward.default, + ]: + transform_targets.append(node.target) + # N.B.: this is not a complete subgraph representing + # original logic, as we are testing the ability to + # modify graph after DTensor expansion. 
+ with gm.graph.inserting_before(node): + new_node = gm.graph.call_function( + torch.add, args=node.args + ) + node.replace_all_uses_with(new_node) + + gm.graph.lint() + gm.graph.eliminate_dead_code() + + return gm + + @compile(module_override={DataDependentModule: DDMOverride()}) + def train_step(mod, opt, inp): + mod(inp).sum().backward() + opt.step() + + rank = torch.distributed.get_rank() + mod = DummyModel(self.world_size).cuda(rank) + opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False) + # FIXME: symbolic tracing treats bs=1 as constant, have to use bs > 1. + inp = torch.randn(4, 10).cuda(rank) + train_step(mod, opt, inp) + + # checking transforms are indeed invoked. + self.assertEqual(transform_targets, [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default]) + + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py index 9a38abeba9c..03ada510752 100644 --- a/torch/distributed/_spmd/api.py +++ b/torch/distributed/_spmd/api.py @@ -1,10 +1,39 @@ -from typing import Dict, Optional, Sequence, Tuple +from abc import ABC, abstractmethod +from contextlib import contextmanager, nullcontext +from copy import copy +from functools import wraps, partial +from typing import ( + Any, + Callable, + Dict, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import torch import torch.distributed as dist import torch.nn as nn -from torch.distributed._spmd.distribute import distribute, Schema +import torch.utils._pytree as pytree +from torch import fx +from torch.distributed._spmd.distribute import ( + _convert_to_distributed, + distribute, + Schema, +) from torch.distributed._spmd.distributed_graph import DistributedGraph -from torch.distributed._tensor import Placement, Replicate +from torch.distributed._tensor import ( + DeviceMesh, + Placement, + Replicate, + Shard, +) +from torch.nn.utils import stateless +from functorch import make_fx +from torch.nn.utils._named_member_accessor 
import NamedMemberAccessor class SPMD(nn.Module): @@ -53,3 +82,253 @@ class SPMD(nn.Module): assert self._compiled_m is not None return self._compiled_m(*args, **kwargs) + + +class Override(ABC): + r""" + Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`. + This is useful when any part of the model is not traceable or if you prefer + to not trace it for any reason. More specifically, users can implement + :meth:`torch.distributed._spmd.Override.replacement` to replace an original + submodule with the returned new submodule. The new submodule contains + operations that users prefer to be traced, which can simply be a dummy + placeholder operator. After tracing, users can implement + :meth:`torch.distributed._spmd.Override.transform` to transform the traced + graph, where the dummy placeholder operator serves as an anchor to insert + new sub-graphs. + """ + + @abstractmethod + def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module: + r""" + Implement this method to return a new :class:`nn.Module` instance to + replace the ``orig_submodule`` argument in the model. This helps if + ``orig_submodule`` is not traceable or should not be traced. + + Args: + orig_submodule (:class:`nn.Module`): original submodule instance to replace. + + Returns: + A new :class:`nn.Module` instance to replace the original one. + """ + pass + + @abstractmethod + def transform( + self, gm: fx.GraphModule, schema_map: Dict[str, Schema] + ) -> fx.Graph: + r""" + Given a DTensor-expanded graph and sharding schema for every node, + conduct additional transformation for the sub-graph from the :class:`nn.Module` + returned by :meth:`torch.distributed._spmd.Override.replacement` if + necessary. + + Args: + gm (:class:`fx.GraphModule`): a DTensor-expanded graph. + schema_map (Dict[str, :class:`Schema`]): a dictionary that maps from node + name to DTensor schema. + + Returns: + The :class:`fx.Graph` after transformation.
+ """ + pass + + +def _dtensor_expand( + gm: fx.GraphModule, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + named_states: Dict[str, Any], + params_and_buffers: Dict[str, Any], +) -> Tuple[fx.GraphModule, Dict[str, Schema]]: + flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values())) + + mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda()) + shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)]) + # FIXME: allow other sharding schemas + replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()]) + + inps, schemas = [], [] + for a in flat_args: + if isinstance(a, torch.Tensor): + inps.append(a) + schemas.append(shard_schema) + elif isinstance(a, nn.Module) or isinstance(a, torch.optim.Optimizer): + # nn.Module or optimizer placeholder is captured by make_fx but + # never used in the graph + inps.append(torch.empty(0)) + schemas.append(shard_schema) + + for o in pytree.tree_flatten(named_states)[0]: + if isinstance(o, torch.Tensor): + inps.append(o) + schemas.append(replicate_schema) + else: + inps.append(torch.empty(0)) + schemas.append(replicate_schema) + + for p in pytree.tree_flatten(params_and_buffers)[0]: + assert isinstance( + p, torch.Tensor + ), f"expecting Tensor but got {type(p)}" + inps.append(p) + schemas.append(replicate_schema) + + return _convert_to_distributed(gm, inps, schemas, _allow_partial=False) + + +@contextmanager +def _rematerialize_optimizer( + opt: torch.optim.Optimizer, + named_states: Dict[str, Any], + params: Dict[str, nn.Parameter], +): + assert opt is not None + + # update opt.state with proxy tensors + orig_states: Dict[str, Any] = copy(opt.state) + for n in named_states: + # opt.state's key type is string, but optimizer uses Parameter as keys + opt.state[params[n]] = named_states[n] # type: ignore[index] + + # FIXME: support multiple parameter groups + param_group = opt.param_groups[0] + orig_params = param_group["params"] + # FIXME(@mrshenli): exclude buffers + 
param_group["params"] = params.values() + + try: + yield + finally: + param_group["params"] = orig_params + opt.state.update(orig_states) + + +@contextmanager +def _enable_compile(): + # The return value of torch._utils.is_compiling changes optimizer behavior. + # We need that function to return True to include optimizer in the graph. + # See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41 + def f_true(): + return True + + orig_is_compiling_code = torch._utils.is_compiling.__code__ + torch._utils.is_compiling.__code__ = f_true.__code__ + try: + yield + finally: + torch._utils.is_compiling.__code__ = orig_is_compiling_code + + +def compile(module_override: Optional[Dict[Type[Any], Override]] = None): + r""" + Compile and optimize a callable, which can be a train step within a training + loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer` + instances from the input arguments and trace operations applied to their + parameters and states. + + Args: + module_override (Optional[Dict[Type[Any], Override]]): a dictionary maps + from target :class:`nn.Module` types to :class:`Override` objects. + The :class:`Override` objects provide :class:`nn.Module` replacements + during tracing and a graph transformation function after tracing. + (Default: ``None``) + """ + + def inner(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + # 1. 
Extract nn.Module and Optimizer from args and kwargs + # FIXME(@mrshenli): support multiple nn.Module instances + # FIXME(@mrshenli): support multiple Optimizer instances + # FIXME(@mrshenli): need to broadcast model to sync parameters + mod, opt = None, None + for arg in pytree.tree_flatten(list(args) + list(kwargs.values()))[ + 0 + ]: + if isinstance(arg, nn.Module): + assert mod is None, "Only support single nn.Module for now" + mod = arg + if isinstance(arg, torch.optim.Optimizer): + assert opt is None, "Only support single Optimizer for now" + opt = arg + + assert ( + mod is not None + ), "Couldn't find nn.Module instances from the arguments." + + # 2. Override target submodules (e.g., MoE) with dummy replacements + if module_override: + accessor = NamedMemberAccessor(mod) + + for typ, override in module_override.items(): + for name, submodule in mod.named_modules(): + if isinstance(submodule, typ): + accessor.swap_submodule( + name, override.replacement(submodule) + ) + + # 3. Trace stateless version of the train_step + params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + named_states = {} + if opt is not None: + opt_states, spec = pytree.tree_flatten(dict(opt.state)) + + # Pass named_states instead of opt.state to stateless_func, because + # the latter uses nn.Parameter as key. During tracing, we need to + # make sure optimizers can find the states using proxy tensors. + for n, p in params_and_buffers.items(): + if p in opt.state: + # opt.state's key type is string, but optimizer uses + # Parameter as keys + named_states[n] = opt.state[p] # type: ignore[index] + + # Lift states and parameters as function arguments so that make_fx + # can trace operations applied to them.
+ def stateless_func( + func, args, kwargs, named_states, params_and_buffers + ): + with stateless._reparametrize_module( + cast(nn.Module, mod), params_and_buffers + ), _rematerialize_optimizer( + opt, named_states, params_and_buffers + ) if opt else nullcontext(): + ret = func(*args, **kwargs) + # make sure updated parameters are returned + return ret, list(mod.parameters()) # type: ignore[union-attr] + + # FIXME: Using symbolic tracing to work around. Otherwise it hits + # shape mismatch error, as we use local inputs to trace local graph + # and use DTensor to expand operators, where DTensor's shape is the + # global shape. + with _enable_compile(): + # FIXME: functionalize crashes with + # "UnsupportedFakeTensorException: meta converter nyi" + gm = make_fx( + partial(stateless_func, func), + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(args, kwargs, named_states, params_and_buffers) + + # 4. Use DTensor to insert collectives + gm, name_to_spec = _dtensor_expand( + gm, args, kwargs, named_states, params_and_buffers + ) + + # 5. Replace previously inserted dummy ones with real graphs. + if module_override: + for _, override in module_override.items(): + gm = override.transform(gm, name_to_spec) + + with torch.no_grad(): + # N.B.: we don't need autograd as backward has already been + # captured in the graph. 
+ return gm(args, kwargs, named_states, params_and_buffers)[0] + + return wrapper + + return inner diff --git a/torch/distributed/_spmd/distribute.py b/torch/distributed/_spmd/distribute.py index c02695524bd..09064f3f5ee 100644 --- a/torch/distributed/_spmd/distribute.py +++ b/torch/distributed/_spmd/distribute.py @@ -445,9 +445,16 @@ def _convert_to_distributed( ) elif isinstance(node.target, torch._ops.OpOverload): - node_replacements[node] = _get_dtensor_dispatch_graph( - node, node_to_obj - ) + if not node.target._schema.name[-1] == "_": + node_replacements[node] = _get_dtensor_dispatch_graph( + node, node_to_obj + ) + else: + # FIXME(@mrshenli, @wanchaol): this prevents DTensor from inserting + # allreduce for partial DTensor objects. + # FIXME: assuming it's inplace on the first argument + node_to_obj[node] = node_to_obj[node.args[0]] + logger.info(f"Skipping expanding inplace operator {node.target.name()}") elif node.op == OP.OUTPUT: if not _allow_partial: # Returns an expanded dummy add node that ensures diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index 1c65dbaf9b5..426c6df7f37 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -82,6 +82,29 @@ def swap_tensor( return orig_tensor +def swap_submodule( + module: "torch.nn.Module", + name: str, + submodule: "torch.nn.Module", +) -> "torch.nn.Module": + if not isinstance(module, torch.nn.Module): + raise TypeError(f"{module} is not an instance of torch.nn.Module") + if not isinstance(submodule, torch.nn.Module): + raise TypeError(f"{submodule} is not an instance of torch.nn.Module") + if "."
in name: + raise KeyError('submodule name can\'t contain "."') + if name == "": + raise KeyError('submodule name can\'t be empty string ""') + if name not in module._modules: + raise KeyError(f"submodule {name} does not exist") + + orig_submodule = module._modules[name] + if not isinstance(orig_submodule, torch.nn.Module): + raise TypeError(f"{name} attribute is not an instance of torch.nn.Module") + module._modules[name] = submodule + return orig_submodule + + class NamedMemberAccessor: """ A class that provides a way to access the submodules and parameters/buffers @@ -128,6 +151,15 @@ class NamedMemberAccessor: self.memo[name] = submodule return submodule + def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module": + """ + Swap the submodule specified by the given ``path`` to ``value``. + For example, to swap the attribute mod.layer1.conv1 use + ``accessor.swap_submodule("layer1.conv1", conv2)``. + """ + prefix, _, attr = path.rpartition(".") + return swap_submodule(self.get_submodule(prefix), attr, value) + def get_tensor(self, name: str) -> torch.Tensor: """ Get the tensor specified by the given path to value.