mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Enable full train_step tracing and customizable dist graph expansion (#97416)
This commit adds an entry point for full `train_step` tracing and expansion. Model forward, backwrd, and optimizer step will be included in one graph. DTensor expansion will be applied on top to insert collective communications. Users can also provide an `Override` implementation to skip non-traceable submodules and directly install submodule logic to the DTensor-expanded graph by inserting `fx.Nodes`. Differential Revision: [D44325177](https://our.internmc.facebook.com/intern/diff/D44325177) Pull Request resolved: https://github.com/pytorch/pytorch/pull/97416 Approved by: https://github.com/yifuwang, https://github.com/wanchaol
This commit is contained in:
parent
e67b58105a
commit
75fb0b6c9f
4 changed files with 566 additions and 8 deletions
|
|
@ -2,17 +2,28 @@
|
|||
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import List
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.fx as fx
|
||||
import torch.nn as nn
|
||||
from torch.distributed._spmd.api import Schema, SPMD
|
||||
from torch.distributed._spmd.api import (
|
||||
Override,
|
||||
Schema,
|
||||
SPMD,
|
||||
compile,
|
||||
)
|
||||
from torch.distributed._spmd.comm_tensor import CommTensor
|
||||
from torch.distributed._tensor import DeviceMesh, Replicate
|
||||
from torch.distributed._tensor.ops.utils import register_prop_rule
|
||||
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
|
||||
from torch.distributed._tensor.placement_types import DTensorSpec
|
||||
from torch.distributed.distributed_c10d import get_global_rank, get_world_size
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
|
|
@ -369,5 +380,234 @@ class TraceModuleTest(DTensorTestBase):
|
|||
)
|
||||
|
||||
|
||||
class DataDependentModule(nn.Module):
|
||||
def __init__(self, world_size):
|
||||
super().__init__()
|
||||
self.world_size = world_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
raise RuntimeError(
|
||||
"This eager implementation shouldn't be executed."
|
||||
"This implementation is just an example of how to get around "
|
||||
"data-dependant user-defined modules. "
|
||||
)
|
||||
shape = x.shape
|
||||
x = x.view(-1)
|
||||
positive = x[x >= 0]
|
||||
negative = x[x < 0]
|
||||
|
||||
in_sizes = torch.tensor(
|
||||
[positive.numel(), negative.numel()], dtype=torch.int32
|
||||
)
|
||||
out_sizes = torch.empty_like(in_sizes)
|
||||
dist.all_to_all_single(
|
||||
out_sizes,
|
||||
in_sizes,
|
||||
output_split_sizes=[1, 1],
|
||||
input_split_sizes=[1, 1],
|
||||
)
|
||||
|
||||
xs = [positive, negative]
|
||||
ys = [
|
||||
torch.Tensor(out_sizes[i].item()) for i in range(out_sizes.numel())
|
||||
]
|
||||
dist.all_to_all(ys, xs)
|
||||
|
||||
# some dummy compute
|
||||
for y in ys:
|
||||
y.add_(1)
|
||||
|
||||
dist.all_to_all(xs, ys)
|
||||
|
||||
return torch.cat(xs).reshape(shape)
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, world_size):
|
||||
super().__init__()
|
||||
self.l1 = nn.Linear(10, 10)
|
||||
self.ddm = DataDependentModule(world_size)
|
||||
self.l2 = nn.Linear(10, 10)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
assert len(x.size()) == 2
|
||||
|
||||
return self.relu(self.l2(self.ddm(self.l1(x))))
|
||||
|
||||
|
||||
def ddm(x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
def ddm_backward(grad: torch.Tensor) -> torch.Tensor:
|
||||
return grad
|
||||
|
||||
|
||||
dummy_lib = torch.library.Library("dummy", "DEF")
|
||||
dummy_lib.define("ddm(Tensor x) -> Tensor")
|
||||
dummy_lib.impl("ddm", ddm, "CompositeExplicitAutograd")
|
||||
dummy_lib.define("ddm_backward(Tensor x) -> Tensor")
|
||||
dummy_lib.impl("ddm_backward", ddm_backward, "CompositeExplicitAutograd")
|
||||
|
||||
|
||||
def _identity_prop_rule(op_schema: OpSchema) -> OutputSharding:
|
||||
(x,) = op_schema.args_schema
|
||||
assert isinstance(x, DTensorSpec), f"expecting DTensorSpec but got {x}"
|
||||
|
||||
return OutputSharding(output_spec=DTensorSpec(x.mesh, x.placements))
|
||||
|
||||
|
||||
@register_prop_rule(torch.ops.dummy.ddm.default)
|
||||
def _prop_ddm(op_schema: OpSchema) -> OutputSharding:
|
||||
return _identity_prop_rule(op_schema)
|
||||
|
||||
|
||||
@register_prop_rule(torch.ops.dummy.ddm_backward.default)
|
||||
def _prop_ddm_backward(op_schema: OpSchema) -> OutputSharding:
|
||||
return _identity_prop_rule(op_schema)
|
||||
|
||||
|
||||
class DDMFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx: Any, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.dummy.ddm(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx: Any, grad_x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.dummy.ddm_backward(grad_x)
|
||||
|
||||
|
||||
class DummyDDM(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return DDMFunction.apply(x)
|
||||
|
||||
|
||||
class TraceTrainStepTest(DTensorTestBase):
|
||||
@property
|
||||
def world_size(self):
|
||||
return 2
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
def test_train_step_simple(self):
|
||||
@compile()
|
||||
def train_step(mod, inp):
|
||||
mod(inp).sum().backward()
|
||||
return [p.grad for p in mod.parameters()]
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
inp = torch.randn(2, 10).cuda(rank)
|
||||
# FIXME(@mrshenli): remove manual seed once dist.compile can synchronize
|
||||
# module parameters.
|
||||
torch.manual_seed(0)
|
||||
mod = nn.Linear(10, 10).cuda(rank)
|
||||
|
||||
ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
|
||||
ddp_inp = deepcopy(inp)
|
||||
|
||||
grads = train_step(mod, inp)
|
||||
ddp_mod(ddp_inp).sum().backward()
|
||||
|
||||
for g1, p2 in zip(grads, ddp_mod.parameters()):
|
||||
# FIXME(@mrshenli): DDP by default divides gradients by world size.
|
||||
# Should we match that behavior?
|
||||
self.assertEqual(g1 / self.world_size, p2.grad)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
def test_sgd(self):
|
||||
@compile()
|
||||
def train_step(mod, opt, inp):
|
||||
mod(inp).sum().backward()
|
||||
opt.step()
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
mod = nn.Linear(10, 10).cuda(rank)
|
||||
# FIXME(@mrshenli): we have to enable foreach to get better perf
|
||||
opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False)
|
||||
inp = torch.zeros(2, 10).cuda(rank)
|
||||
|
||||
mod(inp).sum().backward()
|
||||
opt.step()
|
||||
|
||||
# FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
|
||||
train_step(mod, opt, inp)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
def test_adam(self):
|
||||
@compile()
|
||||
def train_step(mod, opt, inp):
|
||||
mod(inp).sum().backward()
|
||||
opt.step()
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
mod = nn.Linear(10, 10).cuda(rank)
|
||||
opt = torch.optim.Adam(
|
||||
mod.parameters(), lr=0.01, foreach=False, capturable=True
|
||||
)
|
||||
inp = torch.zeros(2, 10).cuda(rank)
|
||||
|
||||
mod(inp).sum().backward()
|
||||
opt.step()
|
||||
|
||||
# FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
|
||||
train_step(mod, opt, inp)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@with_comms
|
||||
def test_train_step_override(self):
|
||||
transform_targets = []
|
||||
|
||||
class DDMOverride(Override):
|
||||
def replacement(
|
||||
self, orig_submodule: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
return DummyDDM()
|
||||
|
||||
def transform(
|
||||
self, gm: fx.GraphModule, schema_map: Dict[str, Schema]
|
||||
) -> fx.Graph:
|
||||
nonlocal transform_targets
|
||||
for node in gm.graph.nodes:
|
||||
if node.target in [
|
||||
torch.ops.dummy.ddm.default,
|
||||
torch.ops.dummy.ddm_backward.default,
|
||||
]:
|
||||
transform_targets.append(node.target)
|
||||
# N.B.: this is not a complete subgraph representing
|
||||
# original logic, as we are testing the ability to
|
||||
# modify graph after DTensor expansion.
|
||||
with gm.graph.inserting_before(node):
|
||||
new_node = gm.graph.call_function(
|
||||
torch.add, args=node.args
|
||||
)
|
||||
node.replace_all_uses_with(new_node)
|
||||
|
||||
gm.graph.lint()
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
return gm
|
||||
|
||||
@compile(module_override={DataDependentModule: DDMOverride()})
|
||||
def train_step(mod, opt, inp):
|
||||
mod(inp).sum().backward()
|
||||
opt.step()
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
mod = DummyModel(self.world_size).cuda(rank)
|
||||
opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False)
|
||||
# FIXME: symbolic tracing treats bs=1 as constant, have to use bs > 1.
|
||||
inp = torch.randn(4, 10).cuda(rank)
|
||||
train_step(mod, opt, inp)
|
||||
|
||||
# checking transforms are indeed invoked.
|
||||
self.assertEqual(transform_targets, [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,39 @@
|
|||
from typing import Dict, Optional, Sequence, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import copy
|
||||
from functools import wraps, partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed._spmd.distribute import distribute, Schema
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
from torch.distributed._spmd.distribute import (
|
||||
_convert_to_distributed,
|
||||
distribute,
|
||||
Schema,
|
||||
)
|
||||
from torch.distributed._spmd.distributed_graph import DistributedGraph
|
||||
from torch.distributed._tensor import Placement, Replicate
|
||||
from torch.distributed._tensor import (
|
||||
DeviceMesh,
|
||||
Placement,
|
||||
Replicate,
|
||||
Shard,
|
||||
)
|
||||
from torch.nn.utils import stateless
|
||||
from functorch import make_fx
|
||||
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
|
||||
|
||||
|
||||
class SPMD(nn.Module):
|
||||
|
|
@ -53,3 +82,253 @@ class SPMD(nn.Module):
|
|||
|
||||
assert self._compiled_m is not None
|
||||
return self._compiled_m(*args, **kwargs)
|
||||
|
||||
|
||||
class Override(ABC):
|
||||
r"""
|
||||
Override the tracing and transformation behavior of :meth:`~torch.distributed._spmd.compile`.
|
||||
This is useful when any part of the model is not traceable or if you prefer
|
||||
to not trace it due to any reason. More specifically, users can implement
|
||||
:meth:`torch.distributed._spmd.Override.replacement` to replace an original
|
||||
submodule with the return new submodule. The new submodule contrains
|
||||
operations that users preferred to be traced, which simply be a dummy
|
||||
placeholder operator. After tracing, users can implement
|
||||
:meth:`torch.distributed._spmd.Override.transform` to transform the traced
|
||||
graph, where the dummy placeholder operator serves as an anchor to insert
|
||||
new sub-graphs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def replacement(self, orig_submodule: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Implement this method to return a new :class:`nn.Module` instance to
|
||||
replace the ``orig_submodule`` argument in the model. This helps if
|
||||
``orig_submodule`` is not traceable or should not be traced.
|
||||
|
||||
Args:
|
||||
orig_submodule (class:`nn.Module`): original submodule instance to replace.
|
||||
|
||||
Returns:
|
||||
A new :class:`nn.Module` instance to replace the original one.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def transform(
|
||||
self, gm: fx.GraphModule, schema_map: Dict[str, Schema]
|
||||
) -> fx.Graph:
|
||||
r"""
|
||||
Given a DTensor-expanded graph and shardig schema for every node,
|
||||
conduct additional transformation for the sub-graph from the :class:`nn.Module`
|
||||
returned by :meth:`torch.distributed._spmd.Override.replacement` if
|
||||
necessary.
|
||||
|
||||
Args:
|
||||
gm (:class:`fx.Graph`): a DTensor-expanded graph.
|
||||
schema_map (Dict[str, :class:`Schema`]): a dictionary maps from node
|
||||
name to DTensor schema.
|
||||
|
||||
Returns:
|
||||
The :class:`fx.Graph` after transformation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def _dtensor_expand(
|
||||
gm: fx.GraphModule,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
named_states: Dict[str, Any],
|
||||
params_and_buffers: Dict[str, Any],
|
||||
) -> Tuple[fx.GraphModule, Dict[str, Schema]]:
|
||||
flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values()))
|
||||
|
||||
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
|
||||
shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
|
||||
# FIXME: allow other sharding schemas
|
||||
replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])
|
||||
|
||||
inps, schemas = [], []
|
||||
for a in flat_args:
|
||||
if isinstance(a, torch.Tensor):
|
||||
inps.append(a)
|
||||
schemas.append(shard_schema)
|
||||
elif isinstance(a, nn.Module) or isinstance(a, torch.optim.Optimizer):
|
||||
# nn.Module or optimizer placeholder is captured by make_fx but
|
||||
# never used in the graph
|
||||
inps.append(torch.empty(0))
|
||||
schemas.append(shard_schema)
|
||||
|
||||
for o in pytree.tree_flatten(named_states)[0]:
|
||||
if isinstance(o, torch.Tensor):
|
||||
inps.append(o)
|
||||
schemas.append(replicate_schema)
|
||||
else:
|
||||
inps.append(torch.empty(0))
|
||||
schemas.append(replicate_schema)
|
||||
|
||||
for p in pytree.tree_flatten(params_and_buffers)[0]:
|
||||
assert isinstance(
|
||||
p, torch.Tensor
|
||||
), f"expecting Tensor but got {type(p)}"
|
||||
inps.append(p)
|
||||
schemas.append(replicate_schema)
|
||||
|
||||
return _convert_to_distributed(gm, inps, schemas, _allow_partial=False)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _rematerialize_optimizer(
|
||||
opt: torch.optim.Optimizer,
|
||||
named_states: Dict[str, Any],
|
||||
params: Dict[str, nn.Parameter],
|
||||
):
|
||||
assert opt is not None
|
||||
|
||||
# update opt.state with proxy tensors
|
||||
orig_states: Dict[str, Any] = copy(opt.state)
|
||||
for n in named_states:
|
||||
# opt.state's key type is string, but optimizer uses Parameter as keys
|
||||
opt.state[params[n]] = named_states[n] # type: ignore[index]
|
||||
|
||||
# FIXME: support multiple parameter groups
|
||||
param_group = opt.param_groups[0]
|
||||
orig_params = param_group["params"]
|
||||
# FIXME(@mrshenli): exclude buffers
|
||||
param_group["params"] = params.values()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
param_group["params"] = orig_params
|
||||
opt.state.update(orig_states)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _enable_compile():
|
||||
# The return value of torch._utils.is_compiling changes optimizer behavior.
|
||||
# We need that function to return True to include optimizer in the graph.
|
||||
# See: https://github.com/pytorch/pytorch/blob/a524123c91ab399c9dd6882c1189596dd77e7734/torch/optim/optimizer.py#L41
|
||||
def f_true():
|
||||
return True
|
||||
|
||||
orig_is_compiling_code = torch._utils.is_compiling.__code__
|
||||
torch._utils.is_compiling.__code__ = f_true.__code__
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
torch._utils.is_compiling.__code__ = orig_is_compiling_code
|
||||
|
||||
|
||||
def compile(module_override: Optional[Dict[Type[Any], Override]] = None):
|
||||
r"""
|
||||
Compile and optimize a callable, which can be a train step within a training
|
||||
loop. This method will extract :class:`nn.Module` and :class:`torch.optim.Optimizer`
|
||||
instances from the input arguments and trace operations applied to their
|
||||
parameters and states.
|
||||
|
||||
Args:
|
||||
module_override (Optional[Dict[Type[Any], Override]]): a dictionary maps
|
||||
from target :class:`nn.Module` types to :class:`Override` objects.
|
||||
The :class:`Override` objects provide :class:`nn.Module` replacements
|
||||
during tracing and a graph transformation function after tracing.
|
||||
(Default: ``None``)
|
||||
"""
|
||||
|
||||
def inner(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# 1. Extract nn.Module and Optimizer from args and kwargs
|
||||
# FIXME(@mrshenli): support multiple nn.Module instances
|
||||
# FIXME(@mrshenli): support multiple Optiimzer instances
|
||||
# FIXME(@mrshenli): need to broadcast model to sync parameters
|
||||
mod, opt = None, None
|
||||
for arg in pytree.tree_flatten(list(args) + list(kwargs.values()))[
|
||||
0
|
||||
]:
|
||||
if isinstance(arg, nn.Module):
|
||||
assert mod is None, "Only support single nn.Module for now"
|
||||
mod = arg
|
||||
if isinstance(arg, torch.optim.Optimizer):
|
||||
assert opt is None, "Only support single Optimizer for now"
|
||||
opt = arg
|
||||
|
||||
assert (
|
||||
mod is not None
|
||||
), "Couldn't find nn.Module instances from the arguments."
|
||||
|
||||
# 2. Override target submodules (e.g., MoE) with dummy replacements
|
||||
if module_override:
|
||||
accessor = NamedMemberAccessor(mod)
|
||||
|
||||
for typ, override in module_override.items():
|
||||
for name, submodule in mod.named_modules():
|
||||
if isinstance(submodule, typ):
|
||||
accessor.swap_submodule(
|
||||
name, override.replacement(submodule)
|
||||
)
|
||||
|
||||
# 3. Trace statelss version of the train_step
|
||||
params_and_buffers: Dict[str, Union[torch.Tensor, nn.Parameter]] = {
|
||||
**dict(mod.named_parameters(remove_duplicate=False)),
|
||||
**dict(mod.named_buffers(remove_duplicate=False)),
|
||||
}
|
||||
|
||||
named_states = {}
|
||||
if opt is not None:
|
||||
opt_states, spec = pytree.tree_flatten(dict(opt.state))
|
||||
|
||||
# Pass named_states instead of opt.state to stateless_func, because
|
||||
# the later uses nn.Parameter as key. During tracing, we need to
|
||||
# make sure optimizers can find the states using proxy tensors.
|
||||
for n, p in params_and_buffers.items():
|
||||
if p in opt.state:
|
||||
# opt.state's key type is string, but optimizer uses
|
||||
# Parameter as keys
|
||||
named_states[n] = opt.state[p] # type: ignore[index]
|
||||
|
||||
# Lift states and parameters as function arguments so that make_fx
|
||||
# can trace operations applied to them.
|
||||
def stateless_func(
|
||||
func, args, kwargs, named_states, params_and_buffers
|
||||
):
|
||||
with stateless._reparametrize_module(
|
||||
cast(nn.Module, mod), params_and_buffers
|
||||
), _rematerialize_optimizer(
|
||||
opt, named_states, params_and_buffers
|
||||
) if opt else nullcontext():
|
||||
ret = func(*args, **kwargs)
|
||||
# make sure updated parameters are returned
|
||||
return ret, list(mod.parameters()) # type: ignore[union-attr]
|
||||
|
||||
# FIXME: Using symbolic tracing to work around. Otherwise it hits
|
||||
# shape mismatch error, as we use local inputs to trace local graph
|
||||
# and use DTensor to expand operators, where DTensor's shape is the
|
||||
# global shape.
|
||||
with _enable_compile():
|
||||
# FIXME: functionalize crashes with
|
||||
# "UnsupportedFakeTensorException: meta converter nyi"
|
||||
gm = make_fx(
|
||||
partial(stateless_func, func),
|
||||
tracing_mode="symbolic",
|
||||
_allow_non_fake_inputs=True,
|
||||
)(args, kwargs, named_states, params_and_buffers)
|
||||
|
||||
# 4. Use DTensor to insert collectives
|
||||
gm, name_to_spec = _dtensor_expand(
|
||||
gm, args, kwargs, named_states, params_and_buffers
|
||||
)
|
||||
|
||||
# 5. Replace previously inserted dummy ones with real graphs.
|
||||
if module_override:
|
||||
for _, override in module_override.items():
|
||||
gm = override.transform(gm, name_to_spec)
|
||||
|
||||
with torch.no_grad():
|
||||
# N.B.: we don't need autograd as backward has already been
|
||||
# captured in the graph.
|
||||
return gm(args, kwargs, named_states, params_and_buffers)[0]
|
||||
|
||||
return wrapper
|
||||
|
||||
return inner
|
||||
|
|
|
|||
|
|
@ -445,9 +445,16 @@ def _convert_to_distributed(
|
|||
)
|
||||
|
||||
elif isinstance(node.target, torch._ops.OpOverload):
|
||||
node_replacements[node] = _get_dtensor_dispatch_graph(
|
||||
node, node_to_obj
|
||||
)
|
||||
if not node.target._schema.name[-1] == "_":
|
||||
node_replacements[node] = _get_dtensor_dispatch_graph(
|
||||
node, node_to_obj
|
||||
)
|
||||
else:
|
||||
# FIXME(@mrshenli, @wanchaol): this prevents DTensor to insert
|
||||
# allreduce for partial DTensor objects.
|
||||
# FIXME: assuming it's inplace on the first arugment
|
||||
node_to_obj[node] = node_to_obj[node.args[0]]
|
||||
logger.info(f"Skipping expanding inplace operator {node.target.name()}")
|
||||
elif node.op == OP.OUTPUT:
|
||||
if not _allow_partial:
|
||||
# Returns an expanded dummy add node that ensures
|
||||
|
|
|
|||
|
|
@ -82,6 +82,29 @@ def swap_tensor(
|
|||
return orig_tensor
|
||||
|
||||
|
||||
def swap_submodule(
|
||||
module: "torch.nn.Module",
|
||||
name: str,
|
||||
submodule: "torch.nn.Module",
|
||||
) -> "torch.nn.Module":
|
||||
if not isinstance(module, torch.nn.Module):
|
||||
raise TypeError(f"{module} is not an instance of torch.nn.Module")
|
||||
if not isinstance(submodule, torch.nn.Module):
|
||||
raise TypeError(f"{submodule} is not an instance of torch.nn.Module")
|
||||
if "." in name:
|
||||
raise KeyError('submodule name can\'t contain "."')
|
||||
if name == "":
|
||||
raise KeyError('submodule name can\'t be empty string ""')
|
||||
if name not in module._modules:
|
||||
raise KeyError(f"submodule {name} does not exist")
|
||||
|
||||
orig_submodule = module._modules[name]
|
||||
if not isinstance(orig_submodule, torch.nn.Module):
|
||||
raise TypeError(f"{name} attribute is not an instance of torch.nn.Module")
|
||||
module._modules[name] = submodule
|
||||
return orig_submodule
|
||||
|
||||
|
||||
class NamedMemberAccessor:
|
||||
"""
|
||||
A class that provides a way to access the submodules and parameters/buffers
|
||||
|
|
@ -128,6 +151,15 @@ class NamedMemberAccessor:
|
|||
self.memo[name] = submodule
|
||||
return submodule
|
||||
|
||||
def swap_submodule(self, path: str, value: "torch.nn.Module") -> "torch.nn.Module":
|
||||
"""
|
||||
Swap the submodule specified by the given ``path`` to ``value``.
|
||||
For example, to swap the attribute mod.layer1.conv1 use
|
||||
``accessor.swap_submodule("layer1.conv1", conv2)``.
|
||||
"""
|
||||
prefix, _, attr = path.rpartition(".")
|
||||
return swap_submodule(self.get_submodule(prefix), attr, value)
|
||||
|
||||
def get_tensor(self, name: str) -> torch.Tensor:
|
||||
"""
|
||||
Get the tensor specified by the given path to value.
|
||||
|
|
|
|||
Loading…
Reference in a new issue