Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
[FSDP2] Move to public torch.distributed.fsdp (#141868)
**Overview**

This PR moves `torch/distributed/_composable/fsdp` to `torch/distributed/fsdp/_fully_shard` and makes the public APIs available from `torch.distributed.fsdp`, e.g.:

```
from torch.distributed.fsdp import fully_shard
```

This is targeting the 2.6 release. I rewrote some of the documentation with (hopefully) improved phrasing.

**Changes for Reland**

- Preserved the public objects from `torch/distributed/_composable/fsdp/fully_shard.py` so that the import path still works internally
- Added a unit test that we can do `from torch.distributed._composable.fsdp.fully_shard import FSDPModule`

Differential Revision: [D66890387](https://our.internmc.facebook.com/intern/diff/D66890387)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141868

Approved by: https://github.com/kwen2501, https://github.com/wconstab, https://github.com/weifengpy, https://github.com/fegin, https://github.com/XilunWu

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
This commit is contained in:
parent 868d62552d
commit 78425bff30
45 changed files with 792 additions and 590 deletions
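For orientation before the file-by-file diff, here is a minimal sketch of the import surface this PR establishes. It is an illustration based on the commit message above, not part of the diff itself.

```python
# Sketch only: the FSDP2 public API exposed from torch.distributed.fsdp after this PR.
from torch.distributed.fsdp import (
    CPUOffloadPolicy,
    FSDPModule,
    MixedPrecisionPolicy,
    OffloadPolicy,
    UnshardHandle,
    fully_shard,
    register_fsdp_forward_method,
)
```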
docs/source/distributed.fsdp.fully_shard.rst (new file, 85 additions)
@@ -0,0 +1,85 @@
torch.distributed.fsdp.fully_shard
==================================

PyTorch FSDP2 (``fully_shard``)
-------------------------------

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation
targeting performant eager-mode while using per-parameter sharding for improved
usability.

- If you are new to FSDP, we recommend that you start with FSDP2 due to improved
  usability.
- If you are currently using FSDP1, consider evaluating the following
  differences to see if you should switch to FSDP2:

Compared to PyTorch FSDP1 (``FullyShardedDataParallel``):

- FSDP2 uses ``DTensor``-based dim-0 per-parameter sharding for a simpler
  sharding representation compared to FSDP1's flat-parameter sharding, while
  preserving similar throughput performance. More specifically, FSDP2 chunks
  each parameter on dim-0 across the data parallel workers (using
  ``torch.chunk(dim=0)``), whereas FSDP1 flattens, concatenates, and chunks a
  group of tensors together, making reasoning about what data is present on
  each worker and resharding to different parallelisms complex. Per-parameter
  sharding provides a more intuitive user experience, relaxes constraints
  around frozen parameters, and allows for communication-free (sharded) state
  dicts, which otherwise require all-gathers in FSDP1.
- FSDP2 implements a different memory management approach to handle the
  multi-stream usages that avoids ``torch.Tensor.record_stream``. This ensures
  deterministic and expected memory usage and does not require blocking the CPU
  like in FSDP1's ``limit_all_gathers=True``.
- FSDP2 exposes APIs for manual control over prefetching and collective
  scheduling, allowing power users more customization. See the methods on
  ``FSDPModule`` below for details.
- FSDP2 simplifies some of the API surface: e.g. FSDP2 does not directly
  support full state dicts. Instead, users can reshard the sharded state dicts
  containing ``DTensor`` s to full state dicts themselves using ``DTensor``
  APIs like ``DTensor.full_tensor()`` or by using higher-level APIs like
  `PyTorch Distributed Checkpoint <https://pytorch.org/docs/stable/distributed.checkpoint.html>`_ 's
  distributed state dict APIs. Also, some other args have been removed; see
  `here <https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md>`_ for
  details.

If you are onboarding FSDP for the first time or if any of the above appeals to
your use case, we recommend that you consider using FSDP2.

See `this RFC <https://github.com/pytorch/pytorch/issues/114299>`_ for details
on system design and implementation.

.. note::
    ``torch.distributed.fsdp.fully_shard`` is currently in prototype state and
    under development. The core API will likely not change, but we may make some
    API changes if necessary.

.. currentmodule:: torch.distributed.fsdp

The frontend API is ``fully_shard`` that can be called on a ``module``:

.. autofunction:: fully_shard

Calling ``fully_shard(module)`` dynamically constructs a new class that
subclasses ``type(module)`` and an FSDP class ``FSDPModule``. For example, if
we call ``fully_shard(linear)`` on a module ``linear: nn.Linear``, then FSDP
constructs a new class ``FSDPLinear`` and changes ``linear`` 's type to this.
Otherwise, ``fully_shard`` does not change the module structure and parameter
fully-qualified names. The class ``FSDPModule`` allows providing some
FSDP-specific methods on the module.

.. autoclass:: FSDPModule
    :members:
    :member-order: bysource

.. autoclass:: UnshardHandle
    :members:

.. autofunction:: register_fsdp_forward_method

.. autoclass:: MixedPrecisionPolicy
    :members:

.. autoclass:: OffloadPolicy
    :members:

.. autoclass:: CPUOffloadPolicy
    :members:
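To make the documented workflow concrete, here is a hedged usage sketch of the API described above. It is an illustration, not part of the new file, and it assumes a typical multi-GPU torchrun launch with the default process group already initialized.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard

# Assumed setup: torchrun has initialized the default process group and set the CUDA device.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)])
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)

# Construct one communication group per layer ("layer by layer"), then one for the root
# to pick up any remaining parameters, as recommended above for memory savings and overlap.
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

# fully_shard swapped the class in place: model is now an FSDPModule subclass (FSDPSequential).
assert isinstance(model, FSDPModule)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()

# One of the FSDPModule control knobs mentioned above: skip gradient reduction
# for micro-batches when accumulating gradients without communication.
model.set_requires_gradient_sync(False)
```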
@@ -79,6 +79,7 @@ Features described in this documentation are classified by release status:
     torch.distributed.algorithms.join <distributed.algorithms.join>
     torch.distributed.elastic <distributed.elastic>
     torch.distributed.fsdp <fsdp>
+    torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
     torch.distributed.tensor.parallel <distributed.tensor.parallel>
     torch.distributed.optim <distributed.optim>
     torch.distributed.pipelining <distributed.pipelining>

@@ -10,7 +10,7 @@ from typing import Any, List, Optional, Type, Union
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -7,8 +7,8 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, MLPStack

@@ -11,30 +11,30 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed._composable.fsdp import (
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import (
     FSDPModule,
     fully_shard,
     MixedPrecisionPolicy,
     OffloadPolicy,
 )
-from torch.distributed._composable.fsdp._fsdp_collectives import (
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
     _div_if_needed,
     _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
 )
-from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState
-from torch.distributed._composable.fsdp._fsdp_init import (
+from torch.distributed.fsdp._fully_shard._fsdp_common import FSDPMeshInfo, TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_init import (
     _get_post_forward_mesh_info,
     _init_default_fully_shard_mesh,
 )
-from torch.distributed._composable.fsdp._fsdp_param import ShardedState
-from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
-from torch.distributed._tensor import DTensor
-from torch.distributed._tensor.experimental import implicit_replication
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
+from torch.distributed.tensor import DTensor
 from torch.distributed.tensor.debug import CommDebugMode
+from torch.distributed.tensor.experimental import implicit_replication
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

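The private helpers referenced in this test also move, from `torch.distributed._composable.fsdp._fsdp_*` to `torch.distributed.fsdp._fully_shard._fsdp_*`. A hedged compatibility sketch (not from the diff) for out-of-tree code that reaches into these internals:

```python
# Sketch only: tolerate both the pre- and post-2.6 private module layout.
try:
    from torch.distributed.fsdp._fully_shard import _fsdp_param_group  # new location (this PR)
except ImportError:
    from torch.distributed._composable.fsdp import _fsdp_param_group  # old location
```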
@@ -12,17 +12,19 @@ from unittest import mock

 import torch
 import torch._dynamo.testing
-import torch.distributed._composable.fsdp._fsdp_param
 import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.utils import counters
 from torch._inductor import comms
 from torch._inductor.utils import is_fallback_op, run_and_get_code
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_common import TrainingState
-from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
 from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
+from torch.distributed.fsdp import (
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
 from torch.testing import FileCheck
 from torch.testing._internal.common_distributed import (
     at_least_x_gpu,

@@ -83,7 +85,7 @@ class TestFullyShardCompileCompute(FSDPTest):
 ):
     torch._dynamo.reset()
     trace_rules_check_count = 0
-    HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
+    HOOKS_FILE_NAME = "torch/distributed/fsdp/_fully_shard/_fsdp_state.py"
     HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

     def patched_trace_rules_check(*args, **kwargs):

@@ -13,8 +13,8 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from torch.autograd.grad_mode import _unsafe_preserve_version_counter
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

@@ -10,8 +10,8 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_param_group import (
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
     RegisterPostBackwardFunction,
 )
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

@@ -4,8 +4,8 @@ import copy
 import torch
 import torch.nn as nn
 from torch.amp.grad_scaler import GradScaler, OptState
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import init_device_mesh
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

@@ -9,13 +9,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_init import (
-    _get_managed_modules,
-    _get_managed_states,
-)
-from torch.distributed._composable.fsdp._fsdp_param import ParamModuleInfo
-from torch.distributed._composable.fsdp._fsdp_param_group import _get_param_module_infos
 from torch.distributed._tensor import (
     DeviceMesh,
     distribute_tensor,

@@ -24,6 +17,15 @@ from torch.distributed._tensor import (
     Shard,
 )
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp._fully_shard._fsdp_init import (
+    _get_managed_modules,
+    _get_managed_states,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_param import ParamModuleInfo
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+    _get_param_module_infos,
+)
 from torch.distributed.fsdp._init_utils import (
     _init_inter_node_process_group,
     _init_intra_node_process_group,

@@ -1156,5 +1158,31 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
         fully_shard(model, shard_placement_fn=shard_placement_fn)


+# TODO: Remove this test class once we remove the old import path:
+# torch/distributed/_composable/fsdp
+class TestFullyShardOldImport(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_old_import_training(self):
+        from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+        from torch.distributed._composable.fsdp.fully_shard import FSDPModule
+
+        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
+        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
+        fully_shard(model[0], mp_policy=mp_policy)
+        fully_shard(model[1], mp_policy=mp_policy)
+        fully_shard(model, mp_policy=mp_policy)
+
+        self.assertIsInstance(model[0], FSDPModule)
+        self.assertIsInstance(model[1], FSDPModule)
+        self.assertIsInstance(model, FSDPModule)
+
+        inp = torch.randn((8, 16), device="cuda")
+        model(inp).sum().backward()
+
+
 if __name__ == "__main__":
     run_tests()

@@ -32,7 +32,7 @@ import logging
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.fsdp import fully_shard
 logger = logging.getLogger("torch.distributed._composable.fsdp")
 logger.setLevel(logging.DEBUG)
 device = "cuda"

@@ -4,11 +4,7 @@ import functools
 import gc

 import torch
-from torch.distributed._composable.fsdp import (
-    CPUOffloadPolicy,
-    fully_shard,
-    OffloadPolicy,
-)
+from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, OffloadPolicy
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
 from torch.testing._internal.common_utils import run_tests

@@ -8,8 +8,8 @@ import torch
 import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-from torch.distributed._composable.fsdp._fsdp_collectives import (
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
     _get_gradient_divide_factors,
 )
 from torch.distributed.tensor import Shard

@@ -7,8 +7,8 @@ from typing import Callable
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor.experimental import implicit_replication
+from torch.distributed.fsdp import fully_shard
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     FSDPTest,

@@ -4,7 +4,7 @@ import copy
 import unittest

 import torch.nn as nn
-from torch.distributed._composable.fsdp import FSDPModule, fully_shard
+from torch.distributed.fsdp import FSDPModule, fully_shard
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
 from torch.testing._internal.common_utils import run_tests

@@ -8,8 +8,8 @@ from typing import Dict, Optional

 import torch
 import torch.nn as nn
-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard
 from torch.distributed.tensor import distribute_tensor, DTensor, Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,

@@ -12,18 +12,18 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed._composable.fsdp import (
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    _CHECKPOINT_PREFIX,
+    apply_activation_checkpointing,
+)
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import (
     CPUOffloadPolicy,
     FSDPModule,
     fully_shard,
     OffloadPolicy,
     register_fsdp_forward_method,
 )
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    _CHECKPOINT_PREFIX,
-    apply_activation_checkpointing,
-)
-from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_cuda import TEST_CUDA

@@ -671,7 +671,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
         module_grouping: str,
     ):
         assert checkpoint_impl in ("composable", "utils", "wrapper")
-        testing_compile = fully_shard != torch.distributed._composable.fsdp.fully_shard
+        testing_compile = fully_shard != torch.distributed.fsdp.fully_shard
         if testing_compile and checkpoint_impl == "composable":
             return
         torch.manual_seed(42)

@@ -12,7 +12,6 @@ import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
 from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,

@@ -22,7 +21,11 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
 )
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+)
 from torch.distributed.fsdp._common_utils import (
     _get_module_fsdp_state,
     clean_tensor_name,

@@ -6,10 +6,6 @@ from typing import TYPE_CHECKING
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed._composable.fsdp.fully_shard import (
-    fully_shard,
-    MixedPrecisionPolicy,
-)
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner

@@ -17,6 +13,7 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
     PipelineScheduleSingle,

@@ -7,9 +7,9 @@ import torch
 import torch.distributed as dist
 import torch.nn.functional as F
 from torch import nn
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import DTensor
+from torch.distributed.fsdp import fully_shard
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,

@@ -6,12 +6,12 @@ import itertools
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.tensor._random as random
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import DeviceMesh, DTensor, init_device_mesh
 from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
 from torch.distributed._tensor.api import distribute_tensor
 from torch.distributed._tensor.placement_types import Replicate, Shard
 from torch.distributed.distributed_c10d import broadcast_object_list
+from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor._random import (
     is_rng_supported_mesh,
     manual_seed,

@@ -6,18 +6,18 @@ from typing import Union
 import torch
 import torch.nn as nn
 from torch.distributed._composable import checkpoint
-from torch.distributed._composable.fsdp import (
-    CPUOffloadPolicy,
-    fully_shard,
-    MixedPrecisionPolicy,
-    OffloadPolicy,
-)
 from torch.distributed._tensor import init_device_mesh
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
     CheckpointWrapper,
 )
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    fully_shard,
+    MixedPrecisionPolicy,
+    OffloadPolicy,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, MLP
 from torch.testing._internal.common_utils import run_tests

@@ -6,7 +6,6 @@ import copy
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import DTensor, init_device_mesh
 from torch.distributed._tensor.experimental import implicit_replication
 from torch.distributed.checkpoint.state_dict import (

@@ -14,7 +13,11 @@ from torch.distributed.checkpoint.state_dict import (
     get_optimizer_state_dict,
     StateDictOptions,
 )
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
+from torch.distributed.fsdp import (
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+)
 from torch.distributed.fsdp.wrap import always_wrap_policy
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,

@@ -10,7 +10,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
-from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor, init_device_mesh
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (

@@ -28,6 +27,7 @@ from torch.distributed.checkpoint.state_dict import (
     StateDictOptions,
 )
 from torch.distributed.fsdp import (
+    fully_shard,
     FullyShardedDataParallel as FSDP,
     ShardingStrategy,
     StateDictType,

@@ -3263,7 +3263,7 @@ if torch.distributed.is_available():
         "torch.distributed._composable.replicate",
     }
     if not torch._dynamo.config.skip_fsdp_hooks:
-        LEGACY_MOD_INLINELIST.add("torch.distributed._composable.fsdp")
+        LEGACY_MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")


 # Force inline functions under these modules, even they are in *_SKIPLIST.

@@ -3323,7 +3323,7 @@ MOD_INLINELIST = set(MOD_INLINELIST)
 if torch.distributed.is_available():
     MOD_INLINELIST.add("torch.distributed")
     if not torch._dynamo.config.skip_fsdp_hooks:
-        MOD_INLINELIST.add("torch.distributed._composable.fsdp")
+        MOD_INLINELIST.add("torch.distributed.fsdp._fully_shard")


 @functools.lru_cache(None)

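Both trace-rule entries above are gated on `torch._dynamo.config.skip_fsdp_hooks`. A minimal sketch of how that flag relates to the renamed module (illustration only, not from the diff):

```python
import torch._dynamo

# When skip_fsdp_hooks is False, Dynamo traces and inlines the FSDP2 hook code,
# so the inline-lists above must name the new module path
# torch.distributed.fsdp._fully_shard instead of torch.distributed._composable.fsdp.
torch._dynamo.config.skip_fsdp_hooks = False
```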
@@ -994,7 +994,7 @@ class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable):
         self.param_group_var.value._training_state = value

     def module_name(self):
-        return "torch.distributed._composable.fsdp._fsdp_param_group.FSDPParamGroup"
+        return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup"

     def fn_name(self):
         return "use_training_state"

@@ -30,7 +30,7 @@ from .constant import ConstantVariable


 try:
-    from torch.distributed._composable.fsdp import _fsdp_param_group
+    from torch.distributed.fsdp._fully_shard import _fsdp_param_group
 except ModuleNotFoundError:
     _fsdp_param_group = None

@@ -305,7 +305,7 @@ class UserFunctionVariable(BaseUserFunctionVariable):
             and not tx.output.current_tracer.allow_side_effects_under_checkpoint
         ):
             try:
-                from torch.distributed._composable.fsdp._fsdp_state import FSDPState
+                from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
             except Exception:
                 FSDPState = None
             if FSDPState is not None and self.fn in [

@@ -56,7 +56,7 @@ except ModuleNotFoundError:
     np = None  # type: ignore[assignment]

 try:
-    from torch.distributed._composable.fsdp import _fsdp_param_group
+    from torch.distributed.fsdp._fully_shard import _fsdp_param_group
 except ModuleNotFoundError:
     _fsdp_param_group = None  # type: ignore[assignment]

@@ -536,7 +536,7 @@ Graph: {graph}

 def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
     try:
-        import torch.distributed._composable.fsdp._fsdp_collectives
+        import torch.distributed.fsdp._fully_shard._fsdp_collectives

         assert torch.distributed.is_available()
         # Assert existence of these ops

@@ -1,2 +1,8 @@
-from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
-from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    FSDPModule,
+    fully_shard,
+    MixedPrecisionPolicy,
+    OffloadPolicy,
+    register_fsdp_forward_method,
+)

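With the shim above, the legacy package becomes a thin re-export of the new public module, so both import paths resolve to the same objects. A hedged check, not part of the diff:

```python
import torch.distributed._composable.fsdp as legacy_fsdp
import torch.distributed.fsdp as public_fsdp

# The re-export means identity holds, which is what the reland unit test relies on.
assert legacy_fsdp.fully_shard is public_fsdp.fully_shard
assert legacy_fsdp.MixedPrecisionPolicy is public_fsdp.MixedPrecisionPolicy
```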
@@ -1,501 +1,8 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
-import functools
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    Iterable,
-    List,
-    NoReturn,
-    Optional,
-    Type,
-    Union,
-)
+# TODO: For backward compatibility, we are importing the public objects
+# originally from this file.
+from torch.distributed.fsdp import (  # noqa: F401
+    FSDPModule,
+    fully_shard,
+    register_fsdp_forward_method,
+    UnshardHandle,
+)
-
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.distributed._composable import contract
|
|
||||||
from torch.distributed.tensor import DeviceMesh, Shard
|
|
||||||
from torch.distributed.utils import _get_root_modules
|
|
||||||
|
|
||||||
from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
|
|
||||||
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
|
|
||||||
from ._fsdp_init import (
|
|
||||||
_get_device_from_mesh,
|
|
||||||
_get_managed_modules,
|
|
||||||
_get_managed_states,
|
|
||||||
_get_post_forward_mesh_info,
|
|
||||||
_init_default_fully_shard_mesh,
|
|
||||||
_move_states_to_device,
|
|
||||||
)
|
|
||||||
from ._fsdp_param_group import FSDPParamGroup
|
|
||||||
from ._fsdp_state import _get_module_fsdp_state, FSDPState
|
|
||||||
|
|
||||||
|
|
||||||
cls_to_fsdp_cls: Dict[Type, Type] = {}
|
|
||||||
|
|
||||||
|
|
||||||
# The decorator adds a state object to `module` that can be accessed via
|
|
||||||
# `fully_shard.state(module)`. The state object and module are 1:1.
|
|
||||||
@contract(state_cls=FSDPState) # type: ignore[operator]
|
|
||||||
def fully_shard(
|
|
||||||
module: Union[nn.Module, List[nn.Module]],
|
|
||||||
*,
|
|
||||||
mesh: Optional[DeviceMesh] = None,
|
|
||||||
reshard_after_forward: Union[bool, int] = True,
|
|
||||||
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
|
|
||||||
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
|
|
||||||
offload_policy: OffloadPolicy = OffloadPolicy(),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Shard module parameters across data parallel workers.
|
|
||||||
|
|
||||||
This function applies fully sharded data parallelism (FSDP) or a variant to
|
|
||||||
``module``, a technique for memory savings at the cost of communication.
|
|
||||||
Parameters are sharded across ``mesh``, and in turn, so are their gradients
|
|
||||||
and optimizer states.
|
|
||||||
|
|
||||||
The sharded parameters are all-gathered to construct the unsharded
|
|
||||||
parameters for forward or backward computation. The unsharded parameters
|
|
||||||
are freed after computation to save memory. The gradients are reduced
|
|
||||||
across the mesh and divided by the mesh size for data parallelism. The
|
|
||||||
optimizer step runs on the sharded parameters.
|
|
||||||
|
|
||||||
Each call to ``fully_shard`` constructs one communication group that
|
|
||||||
includes the parameters in ``module.parameters()`` except those already
|
|
||||||
assigned to a group from a nested call. Each group's parameters and its
|
|
||||||
gradients are communicated together in one collective, respectively.
|
|
||||||
Constructing multiple groups across the model (e.g. "layer by layer")
|
|
||||||
allows for peak memory savings and communication/computation overlap.
|
|
||||||
|
|
||||||
Implementation-wise, the sharded parameters are represented as
|
|
||||||
:class:`DTensor` s, sharded on dim-0, and the unsharded parameters are
|
|
||||||
represented as :class:`Tensor` s. A module forward pre-hook all-gathers the
|
|
||||||
parameters, and a module forward hook frees them. Similar backward hooks
|
|
||||||
gather parameters and later free parameters/reduce gradients.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
module (Union[nn.Module, List[nn.Module]): The module or modules to
|
|
||||||
shard with FSDP and group together for communication.
|
|
||||||
mesh (Optional[DeviceMesh]): This data parallel mesh defines the
|
|
||||||
sharding and device. If 1D, then parameters are fully sharded
|
|
||||||
across the 1D mesh (FSDP). If 2D, then parameters are sharded
|
|
||||||
across the 0th dim and replicated across the 1st dim (HSDP). The
|
|
||||||
mesh's device type gives the device type used for communication;
|
|
||||||
if a CUDA or CUDA-like device type, then we use the current device.
|
|
||||||
reshard_after_forward (Union[bool, int]): This controls the parameter
|
|
||||||
behavior after forward and can trade off memory and communication:
|
|
||||||
- If ``True``, then this reshards parameters after forward and
|
|
||||||
all-gathers in backward.
|
|
||||||
- If ``False``, then this keeps the unsharded parameters in memory
|
|
||||||
after forward and avoids the all-gather in backward.
|
|
||||||
- If an ``int``, then this represents the world size to reshard to
|
|
||||||
after forward. It should be a non-trivial divisor of the ``mesh``
|
|
||||||
shard dim size (i.e. excluding 1 and the dim size itself). A choice
|
|
||||||
may be the intra-node size (e.g. ``torch.cuda.device_count()``).
|
|
||||||
This allows the all-gather in backward to be over a smaller world
|
|
||||||
size at the cost of higher memory usage than setting to ``True``.
|
|
||||||
- The root FSDP state has its value specially set to ``False`` as a
|
|
||||||
heuristic since its parameters would typically be immediately
|
|
||||||
all-gathered for backward.
|
|
||||||
- After forward, the parameters registered to the module depend on
|
|
||||||
to this: The registered parameters are the sharded parameters if
|
|
||||||
``True``; unsharded parameters if ``False``; and the paramters
|
|
||||||
resharded to the smaller mesh otherwise. To modify the parameters
|
|
||||||
between forward and backward, the registered parameters must be the
|
|
||||||
sharded parameters. For ``False`` or an ``int``, this can be done
|
|
||||||
by manually resharding via :meth:`reshard`.
|
|
||||||
shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]):
|
|
||||||
This callable can be used to override the sharding placement for a
|
|
||||||
parameter to shard a parameter on a dimension other than dim-0. If
|
|
||||||
this callable returns a ``Shard`` placement (not ``None``), then
|
|
||||||
FSDP will shard according to that placement (e.g. ``Shard(1)``).
|
|
||||||
If sharding on a nonzero dim, we currently require even sharding,
|
|
||||||
i.e. the tensor dim size on that dim must be divisible by the FSDP
|
|
||||||
shard mesh size.
|
|
||||||
mp_policy (MixedPrecisionPolicy): This controls the mixed precision
|
|
||||||
policy, which offers parameter/reduction mixed precision for this
|
|
||||||
module. See :class:`MixedPrecisionPolicy` for details.
|
|
||||||
offload_policy (OffloadPolicy): This controls the offloading policy,
|
|
||||||
which offers parameter/gradient/optimizer state offloading. See
|
|
||||||
:class:`OffloadPolicy` and its subclasses for details.
|
|
||||||
"""
|
|
||||||
if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
|
|
||||||
raise ValueError(
|
|
||||||
f"fully_shard does not support containers that do not implement forward: {module}"
|
|
||||||
)
|
|
||||||
mesh = mesh or _init_default_fully_shard_mesh()
|
|
||||||
if mesh.ndim not in (1, 2):
|
|
||||||
raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
|
|
||||||
elif mesh.ndim == 1:
|
|
||||||
mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
|
|
||||||
else:
|
|
||||||
if mesh.mesh_dim_names is None:
|
|
||||||
raise AssertionError(
|
|
||||||
"Please init the 2D mesh for HSDP with mesh_dim_names specified"
|
|
||||||
)
|
|
||||||
mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
|
|
||||||
device = _get_device_from_mesh(mesh)
|
|
||||||
post_forward_mesh_info = _get_post_forward_mesh_info(
|
|
||||||
reshard_after_forward, mesh_info
|
|
||||||
)
|
|
||||||
|
|
||||||
arg_module = module
|
|
||||||
modules = (
|
|
||||||
(module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
|
|
||||||
)
|
|
||||||
state = fully_shard.state(modules[0])
|
|
||||||
state.init(modules, device, mp_policy)
|
|
||||||
|
|
||||||
managed_modules = _get_managed_modules(modules)
|
|
||||||
params, buffers = _get_managed_states(managed_modules)
|
|
||||||
_move_states_to_device(params, buffers, device)
|
|
||||||
if params:
|
|
||||||
state._fsdp_param_group = FSDPParamGroup(
|
|
||||||
params,
|
|
||||||
modules,
|
|
||||||
mesh_info,
|
|
||||||
post_forward_mesh_info,
|
|
||||||
device,
|
|
||||||
shard_placement_fn,
|
|
||||||
mp_policy,
|
|
||||||
offload_policy,
|
|
||||||
)
|
|
||||||
|
|
||||||
# For Dynamo
|
|
||||||
for managed_module in managed_modules:
|
|
||||||
managed_module._is_fsdp_managed_module = True # type: ignore[assignment]
|
|
||||||
managed_module._fsdp_use_orig_params = True # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Place FSDP leftmost for highest priority in the method resolution order
|
|
||||||
for module in modules:
|
|
||||||
cls = module.__class__
|
|
||||||
new_cls = cls_to_fsdp_cls.get(cls, None)
|
|
||||||
if not new_cls:
|
|
||||||
dct = {"__deepcopy__": unimplemented_deepcopy}
|
|
||||||
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
|
|
||||||
cls_to_fsdp_cls[cls] = new_cls
|
|
||||||
module.__class__ = new_cls
|
|
||||||
return arg_module
|
|
||||||
|
|
||||||
|
|
||||||
def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
|
|
||||||
raise AssertionError(
|
|
||||||
"FSDP does not support deepcopy. Please use state dict for serialization."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FSDPModule:
|
|
||||||
def __new__(cls, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Override ``__new__`` to remove the FSDP class and directly construct
|
|
||||||
the original class for cases like indexing into a container module.
|
|
||||||
"""
|
|
||||||
# Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
|
|
||||||
# and index 1 is the `FSDPModule` class itself
|
|
||||||
orig_cls = cls.__mro__[2]
|
|
||||||
self = orig_cls.__new__(orig_cls, *args, **kwargs)
|
|
||||||
self.__init__(*args, **kwargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def reshard(self) -> None:
|
|
||||||
"""
|
|
||||||
Reshards the module's parameters, registering the sharded parameters
|
|
||||||
to the module and freeing the unsharded parameters if needed. This
|
|
||||||
method is *not* recursive.
|
|
||||||
"""
|
|
||||||
state = self._get_fsdp_state()
|
|
||||||
if fsdp_param_group := state._fsdp_param_group:
|
|
||||||
fsdp_param_group.reshard()
|
|
||||||
|
|
||||||
def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
|
|
||||||
"""
|
|
||||||
Unshards the module's parameters by allocating memory and all-gathering
|
|
||||||
the parameters. This method is *not* recursive.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
|
|
||||||
that has a :meth:`wait` method to wait on the unshard op. If
|
|
||||||
``False``, then returns ``None`` and waits on the handle inside
|
|
||||||
this function.
|
|
||||||
|
|
||||||
.. warning:: This method is experimental and subject to change.
|
|
||||||
|
|
||||||
.. note:: If ``async_op=True``, then the user does not have to call
|
|
||||||
:meth:`wait` on the returned handle if waiting on the unshard op
|
|
||||||
in the module's pre-forward is tolerable. FSDP will wait on the
|
|
||||||
pending unshard op in the pre-forward automatically.
|
|
||||||
"""
|
|
||||||
state = self._get_fsdp_state()
|
|
||||||
fsdp_param_group = state._fsdp_param_group
|
|
||||||
if fsdp_param_group is not None:
|
|
||||||
fsdp_param_group.lazy_init()
|
|
||||||
fsdp_param_group.unshard(async_op=async_op)
|
|
||||||
handle = UnshardHandle(fsdp_param_group)
|
|
||||||
if async_op:
|
|
||||||
return handle
|
|
||||||
handle.wait()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def set_is_last_backward(self, is_last_backward: bool) -> None:
|
|
||||||
"""
|
|
||||||
Sets whether the next backward is the last one, meaning that FSDP
|
|
||||||
should wait for gradient reduction to finish and clear internal data
|
|
||||||
structures used for explicit prefetching.
|
|
||||||
"""
|
|
||||||
state = self._get_fsdp_state()
|
|
||||||
state._state_ctx.is_last_backward = is_last_backward
|
|
||||||
|
|
||||||
def set_requires_gradient_sync(
|
|
||||||
self, requires_gradient_sync: bool, *, recurse: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Sets if the module should sync gradients. This can be used to implement
|
|
||||||
gradient accumulation without communication. For HSDP, this controls
|
|
||||||
both reduce-scatter and all-reduce together.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requires_gradient_sync (bool): Whether to reduce gradients for the
|
|
||||||
module's parameters.
|
|
||||||
recurse (bool): Whether to set for all submodules or just the
|
|
||||||
passed-in module.
|
|
||||||
"""
|
|
||||||
self_module = cast(nn.Module, self)
|
|
||||||
modules = list(self_module.modules()) if recurse else [self_module]
|
|
||||||
for module in modules:
|
|
||||||
if isinstance(module, FSDPModule):
|
|
||||||
state = module._get_fsdp_state()
|
|
||||||
if fsdp_param_group := state._fsdp_param_group:
|
|
||||||
fsdp_param_group.reduce_grads = requires_gradient_sync
|
|
||||||
fsdp_param_group.all_reduce_grads = requires_gradient_sync
|
|
||||||
|
|
||||||
def set_requires_all_reduce(
|
|
||||||
self, requires_all_reduce: bool, *, recurse: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Sets if the module should all-reduce gradients. This can be used to
|
|
||||||
implement gradient accumulation with only reduce-scatter but not
|
|
||||||
all-reduce for HSDP.
|
|
||||||
"""
|
|
||||||
self_module = cast(nn.Module, self)
|
|
||||||
modules = list(self_module.modules()) if recurse else [self_module]
|
|
||||||
for module in modules:
|
|
||||||
if isinstance(module, FSDPModule):
|
|
||||||
state = module._get_fsdp_state()
|
|
||||||
if fsdp_param_group := state._fsdp_param_group:
|
|
||||||
fsdp_param_group.all_reduce_grads = requires_all_reduce
|
|
||||||
|
|
||||||
def set_reshard_after_backward(
|
|
||||||
self, reshard_after_backward: bool, *, recurse: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Sets if the module should reshard parameters after backward. This can
|
|
||||||
be used during gradient accumulation to trade off higher memory for
|
|
||||||
reduced communication.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
reshard_after_backward (bool): Whether to reshard parameters after
|
|
||||||
backward.
|
|
||||||
recurse (bool): Whether to set for all submodules or just the
|
|
||||||
passed-in module.
|
|
||||||
"""
|
|
||||||
self_module = cast(nn.Module, self)
|
|
||||||
modules = list(self_module.modules()) if recurse else [self_module]
|
|
||||||
for module in modules:
|
|
||||||
if isinstance(module, FSDPModule):
|
|
||||||
state = module._get_fsdp_state()
|
|
||||||
if fsdp_param_group := state._fsdp_param_group:
|
|
||||||
fsdp_param_group.reshard_after_backward = reshard_after_backward
|
|
||||||
|
|
||||||
    def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in forward. The prefetching runs after this
        module's all-gather copy-out.

        Passing a singleton list containing the next FSDP module gives the same
        all-gather overlap behavior as the default overlap behavior, except the
        prefetched all-gather is issued earlier from the CPU. Passing a list
        with at least length two is required for more aggressive overlap and
        will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_forward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in backward. This overrides the default backward
        prefetching implementation that prefetches the next FSDP module based on
        the reverse post-forward order.

        Passing a singleton list containing the previous FSDP module gives the
        same all-gather overlap behavior as the default overlap behavior.
        Passing a list with at least length two is required for more aggressive
        overlap and will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_backward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]
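Usage sketch (illustrative; assumes `model.layers` is a list of blocks, each already wrapped by fully_shard): prefetch two blocks ahead in forward and two behind in backward.

layers = list(model.layers)
for i, layer in enumerate(layers):
    layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 3])
rev = layers[::-1]
for i, layer in enumerate(rev):
    layer.set_modules_to_backward_prefetch(rev[i + 1 : i + 3])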
    def set_post_optim_event(self, event: torch.Event) -> None:
        """
        Sets a post-optimizer-step event for the root FSDP module to wait the
        all-gather streams on.

        By default, the root FSDP module waits the all-gather streams on the
        current stream to ensure that the optimizer step has finished before
        all-gathering. However, this may introduce false dependencies if
        there is unrelated computation after the optimizer step. This API
        allows the user to provide their own event to wait on. After the root
        waits on the event, the event is discarded, so this API should be
        called with a new event each iteration.

        Args:
            event (torch.Event): Event recorded after the optimizer step
                to wait all-gather streams on.
        """
        self._get_fsdp_state()._state_ctx.post_optim_event = event
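Usage sketch (illustrative; assumes a CUDA setup and that `model` is the root FSDP module): record an event right after the optimizer step and hand it to the root each iteration.

optim.step()
post_optim_event = torch.cuda.current_stream().record_event()
model.set_post_optim_event(post_optim_event)  # all-gather streams wait on this event instead of the whole current stream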
    def set_reduce_scatter_divide_factor(self, factor: float) -> None:
        """
        Sets a custom divide factor for the reduce-scatter. This becomes a
        custom reduce op using NCCL's PreMulSum, which allows multiplying by
        the factor before reduction.

        Args:
            factor (float): Custom divide factor.
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            mul_factor = 1.0 / float(factor)
            reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor)
            fsdp_param_group.reduce_scatter_reduce_op = reduce_op
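With a custom factor ``F``, each rank's gradient shard is premultiplied by ``1 / F`` before the summation, so the reduce-scatter effectively computes ``sum_i(grad_i / F)`` in place of the usual averaging over the shard world size. Illustrative sketch (values are assumptions):

model.set_reduce_scatter_divide_factor(1.0)                       # plain sum, no averaging
model.set_reduce_scatter_divide_factor(float(total_batch_size))   # divide by a custom denominator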
    def set_unshard_in_backward(self, unshard_in_backward: bool) -> None:
        """
        Sets whether the FSDP module's parameters need to be unsharded in
        backward. This can be used in expert cases when the user knows that all
        parameters in this FSDP module's parameter group are not needed for
        backward computation (e.g. embedding).
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            fsdp_param_group.unshard_in_backward = unshard_in_backward

    def _set_unshard_async_op(self, async_op: bool):
        """
        Sets whether to use ``async_op=True`` or ``False`` for the pre-forward
        and pre-backward unshard op. This defaults to ``False`` but can be set
        to ``True`` with this method.

        Setting this to ``True`` allows the all-gather allocations to happen in
        the default stream, avoiding inter-stream memory fragmentation.
        However, you must use explicit prefetching (e.g. via :meth:`unshard`)
        in forward to still get overlap, and the pre-all-gather ops like dtype
        casting and copy-in will not overlap with compute.
        """
        self_module = cast(nn.Module, self)
        for module in self_module.modules():
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.unshard_async_op = async_op

    def _get_fsdp_state(self) -> FSDPState:
        if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
            raise AssertionError(f"No FSDP state found on {self}")
        return state

    def _apply(self, *args: Any, **kwargs: Any) -> Any:
        # Reshard to ensure that sharded parameters are registered
        self.reshard()
        ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
        state = self._get_fsdp_state()
        if not (fsdp_param_group := state._fsdp_param_group):
            return ret
        # TODO: Remove this padding logic once DTensor pads the local tensor:
        # https://github.com/pytorch/pytorch/issues/113045
        with torch.no_grad():
            for fsdp_param in fsdp_param_group.fsdp_params:
                fsdp_param.reset_sharded_param()
        return ret


class UnshardHandle:
    """
    A handle to wait on the unshard op.

    Args:
        fsdp_param_group (FSDPParamGroup, optional): FSDP parameter group to
            unshard. This should be ``None`` iff the FSDP module does not
            manage any parameters, meaning the unshard is a no-op.
    """

    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
        self._fsdp_param_group = fsdp_param_group

    def wait(self):
        """
        Waits on the unshard op.

        This ensures that the current stream can use the unsharded parameters,
        which are now registered to the module.
        """
        if self._fsdp_param_group is not None:
            self._fsdp_param_group.wait_for_unshard()
            # Avoid keeping a reference
            self._fsdp_param_group = None
def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
    Registers a method on ``module`` to be a forward method for FSDP.

    FSDP only knows to run its pre-forward and post-forward hooks on the
    default :meth:`nn.Module.forward` method. This function patches a user
    specified method to run the pre/post-forward hooks before/after the method,
    respectively. If ``module`` is not an :class:`FSDPModule`, then this is a
    no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
        method_name (str): Name of the forward method.
    """
    if not isinstance(module, FSDPModule):
        # Make no-op to allow including both when using/not using FSDP
        return
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module)} does not have a method {method_name}")
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        fsdp_state = self._get_fsdp_state()
        args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
        out = orig_method(*args, **kwargs)
        return fsdp_state._post_forward(self, args, out)

    # Use `__get__` to make `wrapped_method` an instance method
    setattr(
        module,
        method_name,
        wrapped_method.__get__(module, type(module)),  # type:ignore[attr-defined]
    )
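Usage sketch (illustrative; `MyModel` and its `generate` method are assumed names): run a non-``forward`` entry point under FSDP's pre/post-forward hooks so parameters are all-gathered for it.

model = MyModel()
fully_shard(model)
register_fsdp_forward_method(model, "generate")
out = model.generate(prompt)  # all-gather runs before, optional free runs after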
def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
    for module in modules:
        if not isinstance(module, FSDPModule):
            raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")
@@ -7,8 +7,6 @@ import torch
 import torch.distributed as dist
 from torch import nn, optim
 from torch._guards import active_fake_mode
-from torch.distributed._composable.fsdp import FSDPModule
-from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
 from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker
 from torch.distributed.distributed_c10d import (
     _IllegalWork,
@@ -16,6 +14,8 @@ from torch.distributed.distributed_c10d import (
     ReduceOp,
     Work,
 )
+from torch.distributed.fsdp import FSDPModule
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
 from torch.futures import Future
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map_only
@@ -1,4 +1,13 @@
 from ._flat_param import FlatParameter as FlatParameter
+from ._fully_shard import (
+    CPUOffloadPolicy,
+    FSDPModule,
+    fully_shard,
+    MixedPrecisionPolicy,
+    OffloadPolicy,
+    register_fsdp_forward_method,
+    UnshardHandle,
+)
 from .fully_sharded_data_parallel import (
     BackwardPrefetch,
     CPUOffload,
@@ -20,6 +29,7 @@ from .fully_sharded_data_parallel import (


 __all__ = [
+    # FSDP1
     "BackwardPrefetch",
     "CPUOffload",
     "FullOptimStateDictConfig",
@@ -36,4 +46,21 @@ __all__ = [
     "StateDictConfig",
     "StateDictSettings",
     "StateDictType",
+    # FSDP2
+    "CPUOffloadPolicy",
+    "FSDPModule",
+    "fully_shard",
+    "MixedPrecisionPolicy",
+    "OffloadPolicy",
+    "register_fsdp_forward_method",
+    "UnshardHandle",
 ]
+
+# Set namespace for exposed private names
+CPUOffloadPolicy.__module__ = "torch.distributed.fsdp"
+FSDPModule.__module__ = "torch.distributed.fsdp"
+fully_shard.__module__ = "torch.distributed.fsdp"
+MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp"
+OffloadPolicy.__module__ = "torch.distributed.fsdp"
+register_fsdp_forward_method.__module__ = "torch.distributed.fsdp"
+UnshardHandle.__module__ = "torch.distributed.fsdp"
torch/distributed/fsdp/_fully_shard/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy
from ._fully_shard import (
    FSDPModule,
    fully_shard,
    register_fsdp_forward_method,
    UnshardHandle,
)


__all__ = [
    "CPUOffloadPolicy",
    "FSDPModule",
    "fully_shard",
    "MixedPrecisionPolicy",
    "OffloadPolicy",
    "register_fsdp_forward_method",
    "UnshardHandle",
]
@@ -57,7 +57,10 @@ class MixedPrecisionPolicy:

 @dataclass
 class OffloadPolicy:
-    """This base class represents the policy of no offloading."""
+    """
+    This base class represents the policy of no offloading and is only used as
+    the default value for the ``offload_policy`` arg.
+    """


 @dataclass
@@ -71,10 +74,10 @@ class CPUOffloadPolicy(OffloadPolicy):

     Attributes:
         pin_memory (bool): Whether to pin sharded parameter and gradient
-            memory. Pinning memory allows H2D/D2H copying without blocking the
-            CPU and in turn, overlap with compute, but pinned memory cannot be
-            used by other processes. Set this to ``False`` if you have
-            insufficient CPU memory. (Default: ``True``)
+            memory. Pinning memory allows both more efficient H2D/D2H copies
+            and for the copies to overlap with compute. However, the pinned
+            memory cannot be used by other processes. Set this to ``False`` if
+            you have insufficient CPU memory. (Default: ``True``)
     """

     pin_memory: bool = True
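Usage sketch (illustrative; `model` is an assumed user module): combine CPU offloading with bf16 parameter mixed precision when sharding.

from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard

fully_shard(
    model,
    mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
    offload_policy=CPUOffloadPolicy(pin_memory=True),
)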
@@ -29,7 +29,7 @@ from ._fsdp_common import (
 from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState


-logger = logging.getLogger("torch.distributed._composable.fsdp")
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")

 _ModuleToHandleDict = Dict[nn.Module, RemovableHandle]  # for state dict

@@ -42,7 +42,7 @@ if TYPE_CHECKING:
     from ._fsdp_param import FSDPParam


-logger = logging.getLogger("torch.distributed._composable.fsdp")
+logger = logging.getLogger("torch.distributed.fsdp.fully_shard")


 class FSDPStateContext:
torch/distributed/fsdp/_fully_shard/_fully_shard.py (new file, 523 lines)
@@ -0,0 +1,523 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    NoReturn,
    Optional,
    Type,
    Union,
)

import torch
import torch.nn as nn
from torch.distributed._composable import contract
from torch.distributed.tensor import DeviceMesh, Shard
from torch.distributed.utils import _get_root_modules

from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
from ._fsdp_init import (
    _get_device_from_mesh,
    _get_managed_modules,
    _get_managed_states,
    _get_post_forward_mesh_info,
    _init_default_fully_shard_mesh,
    _move_states_to_device,
)
from ._fsdp_param_group import FSDPParamGroup
from ._fsdp_state import _get_module_fsdp_state, FSDPState


__all__ = [
    "fully_shard",
    "FSDPModule",
    "UnshardHandle",
    "register_fsdp_forward_method",
]


cls_to_fsdp_cls: Dict[Type, Type] = {}


# The decorator adds a state object to `module` that can be accessed via
# `fully_shard.state(module)`. The state object and module are 1:1.
@contract(state_cls=FSDPState)  # type: ignore[operator]
def fully_shard(
    module: Union[nn.Module, List[nn.Module]],
    *,
    mesh: Optional[DeviceMesh] = None,
    reshard_after_forward: Union[bool, int] = True,
    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
    offload_policy: OffloadPolicy = OffloadPolicy(),
):
    """
    Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP
    shards module parameters, gradients, and optimizer states across data
    parallel workers to save memory at the cost of communication.

    At initialization, FSDP shards the module's parameters across the data
    parallel workers given by ``mesh``. Before forward, FSDP all-gathers the
    sharded parameters across the data-parallel workers to get the unsharded
    parameters for forward computation. If ``reshard_after_forward`` is
    ``True``, then FSDP frees the unsharded parameters after forward and
    re-all-gathers them in backward before gradient computation. After gradient
    computation, FSDP frees the unsharded parameters and reduce-scatters the
    unsharded gradients across data-parallel workers.

    This implementation represents the sharded parameters as :class:`DTensor` s
    sharded on dim-0, while the unsharded parameters will be like the original
    parameters on ``module`` (e.g. :class:`torch.Tensor` if originally
    :class:`torch.Tensor`). A module
    `forward pre-hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_pre_hook>`_
    on ``module`` all-gathers the parameters, and a module
    `forward hook <https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook>`_
    on ``module`` frees them (if needed). Similar backward hooks all-gather
    parameters and later free parameters and reduce-scatter gradients.

    Since grouping multiple tensors together for one collective is critical for
    communication efficiency, this implementation makes this grouping first
    class. Calling :meth:`fully_shard` on ``module`` constructs one group that
    includes the parameters in ``module.parameters()`` except those already
    assigned to a group from an earlier call on a submodule. This means that
    :meth:`fully_shard` should be called bottom-up on your model. Each group's
    parameters are all-gathered in one collective, and its gradients are
    reduce-scattered in one collective. Partitioning the model into multiple
    groups ("layer by layer") allows for peak memory savings and communication/computation
    overlap. Users generally should *not* call :meth:`fully_shard` only on the
    topmost root module.

    Args:
        module (Union[nn.Module, List[nn.Module]]): The module or modules to
            shard with FSDP and group together for communication.
        mesh (Optional[DeviceMesh]): This data parallel mesh defines the
            sharding and device. If 1D, then parameters are fully sharded
            across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D,
            then parameters are sharded across the 1st dim and replicated
            across the 0th dim (HSDP) with ``(Replicate(), Shard(0))``
            placement. The mesh's device type gives the device type used for
            communication; if a CUDA or CUDA-like device type, then we use the
            current device.
        reshard_after_forward (Union[bool, int]): This controls the parameter
            behavior after forward and can trade off memory and communication:

            - If ``True``, then this reshards parameters after forward and
              re-all-gathers in backward.
            - If ``False``, then this keeps the unsharded parameters in memory
              after forward and avoids the all-gather in backward.
            - If an ``int``, then this represents the world size to reshard to
              after forward. It should be a non-trivial divisor of the ``mesh``
              shard dim size (i.e. excluding 1 and the dim size itself). A
              choice may be the intra-node size (e.g. ``torch.cuda.device_count()``).
              This allows the all-gather in backward to be over a smaller world
              size at the cost of higher memory usage than setting to ``True``.
            - The root FSDP state has its value specially set to ``False`` as a
              heuristic since its parameters would typically be immediately
              all-gathered for backward.
            - After forward, the parameters registered to the module depend on
              this: The registered parameters are the sharded parameters if
              ``True``; unsharded parameters if ``False``; and the parameters
              resharded to the smaller mesh otherwise. To modify the parameters
              between forward and backward, the registered parameters must be
              the sharded parameters. For ``False`` or an ``int``, this can be
              done by manually resharding via :meth:`reshard`.
        shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]):
            This callable can be used to override the sharding placement for a
            parameter to shard a parameter on a dimension other than dim-0. If
            this callable returns a :class:`Shard` placement (not ``None``),
            then FSDP will shard according to that placement (e.g. ``Shard(1)``).
            If sharding on a nonzero dim, we currently require even sharding,
            i.e. the tensor dim size on that dim must be divisible by the FSDP
            shard mesh size.
        mp_policy (MixedPrecisionPolicy): This controls the mixed precision
            policy, which offers parameter/reduction mixed precision for this
            module. See :class:`MixedPrecisionPolicy` for details.
        offload_policy (OffloadPolicy): This controls the offloading policy,
            which offers parameter/gradient/optimizer state offloading. See
            :class:`OffloadPolicy` and its subclasses for details.
    """
    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
        raise ValueError(
            f"fully_shard does not support containers that do not implement forward: {module}"
        )
    mesh = mesh or _init_default_fully_shard_mesh()
    if mesh.ndim not in (1, 2):
        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
    elif mesh.ndim == 1:
        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
    else:
        if mesh.mesh_dim_names is None:
            raise AssertionError(
                "Please init the 2D mesh for HSDP with mesh_dim_names specified"
            )
        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
    device = _get_device_from_mesh(mesh)
    post_forward_mesh_info = _get_post_forward_mesh_info(
        reshard_after_forward, mesh_info
    )

    arg_module = module
    modules = (
        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
    )
    state = fully_shard.state(modules[0])
    state.init(modules, device, mp_policy)

    managed_modules = _get_managed_modules(modules)
    params, buffers = _get_managed_states(managed_modules)
    _move_states_to_device(params, buffers, device)
    if params:
        state._fsdp_param_group = FSDPParamGroup(
            params,
            modules,
            mesh_info,
            post_forward_mesh_info,
            device,
            shard_placement_fn,
            mp_policy,
            offload_policy,
        )

    # For Dynamo
    for managed_module in managed_modules:
        managed_module._is_fsdp_managed_module = True  # type: ignore[assignment]
        managed_module._fsdp_use_orig_params = True  # type: ignore[assignment]

    # Place FSDP leftmost for highest priority in the method resolution order
    for module in modules:
        cls = module.__class__
        new_cls = cls_to_fsdp_cls.get(cls, None)
        if not new_cls:
            dct = {"__deepcopy__": _unimplemented_deepcopy}
            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
            cls_to_fsdp_cls[cls] = new_cls
        module.__class__ = new_cls
    return arg_module
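Usage sketch (illustrative; `model.layers` and `world_size` are assumed names): apply ``fully_shard`` bottom-up, one group per block, then once on the root so the remaining parameters form the root group.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mesh = init_device_mesh("cuda", (world_size,))  # 1D mesh -> FSDP; a 2D mesh would give HSDP
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model.layers:
    fully_shard(layer, mesh=mesh, mp_policy=mp, reshard_after_forward=True)
fully_shard(model, mesh=mesh, mp_policy=mp)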
def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    raise AssertionError(
        "FSDP does not support deepcopy. Please use state dict for serialization."
    )


class FSDPModule:
    def __new__(cls, *args, **kwargs):
        """
        Override ``__new__`` to remove the FSDP class and directly construct
        the original class for cases like indexing into a container module.
        """
        # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class
        # and index 1 is the `FSDPModule` class itself
        orig_cls = cls.__mro__[2]
        self = orig_cls.__new__(orig_cls, *args, **kwargs)
        self.__init__(*args, **kwargs)
        return self

    def reshard(self) -> None:
        """
        Reshards the module's parameters, freeing the unsharded parameters if
        they are allocated and registering the sharded parameters to the
        module. This method is *not* recursive.
        """
        state = self._get_fsdp_state()
        if fsdp_param_group := state._fsdp_param_group:
            fsdp_param_group.reshard()

    def unshard(self, async_op: bool = False) -> Optional["UnshardHandle"]:
        """
        Unshards the module's parameters by allocating memory and all-gathering
        the parameters. This method is *not* recursive. The unshard follows the
        :class:`MixedPrecisionPolicy`, so it will all-gather following
        ``param_dtype`` if set.

        Args:
            async_op (bool): If ``True``, then returns a :class:`UnshardHandle`
                that has a :meth:`wait` method to wait on the unshard op. If
                ``False``, then returns ``None`` and waits on the handle inside
                this function.

        .. note:: If ``async_op=True``, then FSDP will wait on the pending
            unshard in the module's pre-forward for the user. The user only
            needs to call :meth:`wait` explicitly if the wait should happen
            before pre-forward.
        """
        state = self._get_fsdp_state()
        fsdp_param_group = state._fsdp_param_group
        if fsdp_param_group is not None:
            fsdp_param_group.lazy_init()
            fsdp_param_group.unshard(async_op=async_op)
        handle = _UnshardHandleImpl(fsdp_param_group)
        if async_op:
            return handle
        handle.wait()
        return None
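Usage sketch (illustrative; `fsdp_layer` and `other_work` are assumed names): overlap an explicit all-gather with unrelated work.

handle = fsdp_layer.unshard(async_op=True)  # issue the all-gather now
other_work()                                # unrelated computation overlaps with the all-gather
handle.wait()                               # unsharded parameters are registered after this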
def set_is_last_backward(self, is_last_backward: bool) -> None:
|
||||||
|
"""
|
||||||
|
Sets whether the next backward is the last one. On the last backward,
|
||||||
|
FSDP waits on pending gradient reduction and clears internal data
|
||||||
|
data structures for backward prefetching. This can be useful for
|
||||||
|
microbatching.
|
||||||
|
"""
|
||||||
|
state = self._get_fsdp_state()
|
||||||
|
state._state_ctx.is_last_backward = is_last_backward
|
||||||
|
|
||||||
|
    def set_requires_gradient_sync(
        self, requires_gradient_sync: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should sync gradients. This can be used to implement
        gradient accumulation *without communication*. For HSDP, this controls
        both reduce-scatter and all-reduce together.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
            recurse (bool): Whether to set for all FSDP submodules or just the
                passed-in module.
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.reduce_grads = requires_gradient_sync
                    fsdp_param_group.all_reduce_grads = requires_gradient_sync
    def set_requires_all_reduce(
        self, requires_all_reduce: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should all-reduce gradients. This can be used to
        implement gradient accumulation with only reduce-scatter but not
        all-reduce for HSDP.
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.all_reduce_grads = requires_all_reduce
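Usage sketch (illustrative, HSDP; `model`, `optim`, and `microbatches` are assumed names): reduce-scatter every microbatch but all-reduce across replicas only on the last one.

for i, batch in enumerate(microbatches):
    model.set_requires_all_reduce(i == len(microbatches) - 1)
    model(batch).sum().backward()
optim.step()
optim.zero_grad()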
    def set_reshard_after_backward(
        self, reshard_after_backward: bool, *, recurse: bool = True
    ) -> None:
        """
        Sets if the module should reshard parameters after backward. This can
        be used during gradient accumulation to trade off higher memory for
        reduced communication since the unsharded parameters do not need to be
        re-all-gathered before the next forward.

        Args:
            reshard_after_backward (bool): Whether to reshard parameters after
                backward.
            recurse (bool): Whether to set for all FSDP submodules or just the
                passed-in module.
        """
        self_module = cast(nn.Module, self)
        modules = list(self_module.modules()) if recurse else [self_module]
        for module in modules:
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.reshard_after_backward = reshard_after_backward
    def set_modules_to_forward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in forward. The prefetching runs after this
        module's all-gather copy-out.

        Passing a singleton list containing the next FSDP module gives the same
        all-gather overlap behavior as the default overlap behavior, except the
        prefetched all-gather is issued earlier from the CPU. Passing a list
        with at least length two is required for more aggressive overlap and
        will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_forward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]

    def set_modules_to_backward_prefetch(self, modules: List["FSDPModule"]) -> None:
        """
        Sets the FSDP modules for which this FSDP module should explicitly
        prefetch all-gathers in backward. This overrides the default backward
        prefetching implementation that prefetches the next FSDP module based on
        the reverse post-forward order.

        Passing a singleton list containing the previous FSDP module gives the
        same all-gather overlap behavior as the default overlap behavior.
        Passing a list with at least length two is required for more aggressive
        overlap and will use more reserved memory.

        Args:
            modules (List[FSDPModule]): FSDP modules to prefetch.
        """
        _assert_all_fsdp_modules(modules)
        self._get_fsdp_state()._states_to_backward_prefetch = [
            module._get_fsdp_state() for module in modules
        ]
    def set_post_optim_event(self, event: torch.Event) -> None:
        """
        Sets a post-optimizer-step event for the root FSDP module to wait the
        all-gather streams on.

        By default, the root FSDP module waits the all-gather streams on the
        current stream to ensure that the optimizer step has finished before
        all-gathering. However, this may introduce false dependencies if
        there is unrelated computation after the optimizer step. This API
        allows the user to provide their own event to wait on. After the root
        waits on the event, the event is discarded, so this API should be
        called with a new event each iteration.

        Args:
            event (torch.Event): Event recorded after the optimizer step
                to wait all-gather streams on.
        """
        self._get_fsdp_state()._state_ctx.post_optim_event = event
    def set_reduce_scatter_divide_factor(self, factor: float) -> None:
        """
        Sets a custom divide factor for the reduce-scatter. This becomes a
        custom reduce op using NCCL's PreMulSum, which allows multiplying by
        the factor before reduction.

        Args:
            factor (float): Custom divide factor.
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            mul_factor = 1.0 / float(factor)
            reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor)
            fsdp_param_group.reduce_scatter_reduce_op = reduce_op
    def set_unshard_in_backward(self, unshard_in_backward: bool) -> None:
        """
        Sets whether the FSDP module's parameters need to be unsharded in
        backward. This can be used in expert cases when the user knows that all
        parameters in this FSDP module's parameter group are not needed for
        backward computation (e.g. embedding).
        """
        state = self._get_fsdp_state()
        if (fsdp_param_group := state._fsdp_param_group) is not None:
            fsdp_param_group.unshard_in_backward = unshard_in_backward
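Usage sketch (illustrative; `model.tok_embeddings` is an assumed submodule): an embedding's weights are not needed to compute the gradients flowing into it, so its backward all-gather can be skipped.

fully_shard(model.tok_embeddings)
model.tok_embeddings.set_unshard_in_backward(False)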
    def _set_unshard_async_op(self, async_op: bool):
        """
        Sets whether to use ``async_op=True`` or ``False`` for the pre-forward
        and pre-backward unshard op. This defaults to ``False`` but can be set
        to ``True`` with this method.

        Setting this to ``True`` allows the all-gather allocations to happen in
        the default stream, avoiding inter-stream memory fragmentation.
        However, you must use explicit prefetching (e.g. via :meth:`unshard`)
        in forward to still get overlap, and the pre-all-gather ops like dtype
        casting and copy-in will not overlap with compute.
        """
        self_module = cast(nn.Module, self)
        for module in self_module.modules():
            if isinstance(module, FSDPModule):
                state = module._get_fsdp_state()
                if fsdp_param_group := state._fsdp_param_group:
                    fsdp_param_group.unshard_async_op = async_op

    def _get_fsdp_state(self) -> FSDPState:
        if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None:
            raise AssertionError(f"No FSDP state found on {self}")
        return state

    def _apply(self, *args: Any, **kwargs: Any) -> Any:
        # Reshard to ensure that sharded parameters are registered
        self.reshard()
        ret = super()._apply(*args, **kwargs)  # type: ignore[misc]
        state = self._get_fsdp_state()
        if not (fsdp_param_group := state._fsdp_param_group):
            return ret
        # TODO: Remove this padding logic once DTensor pads the local tensor:
        # https://github.com/pytorch/pytorch/issues/113045
        with torch.no_grad():
            for fsdp_param in fsdp_param_group.fsdp_params:
                fsdp_param.reset_sharded_param()
        return ret


class UnshardHandle:
    """
    A handle to wait on a :meth:`FSDPModule.unshard` op.
    """

    def wait(self) -> None:
        """
        Waits on the unshard op. This ensures that the current stream can use
        the unsharded parameters, which are now registered to the module.
        """
        return


class _UnshardHandleImpl(UnshardHandle):
    def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]):
        self._fsdp_param_group = fsdp_param_group

    def wait(self):
        if self._fsdp_param_group is not None:
            self._fsdp_param_group.wait_for_unshard()
            # Avoid keeping a reference
            self._fsdp_param_group = None
def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None:
    """
    Registers a method on ``module`` to be considered a forward method for
    FSDP.

    FSDP all-gathers parameters pre-forward and optionally frees parameters
    post-forward (depending on ``reshard_after_forward``). FSDP only knows to
    do this for :meth:`nn.Module.forward` by default. This function patches a
    user-specified method to run the pre/post-forward hooks before/after the
    method, respectively. If ``module`` is not an :class:`FSDPModule`, then
    this is a no-op.

    Args:
        module (nn.Module): Module to register the forward method on.
        method_name (str): Name of the forward method.
    """
    if not isinstance(module, FSDPModule):
        # Make no-op to allow including both when using/not using FSDP
        return
    if not hasattr(module, method_name):
        raise ValueError(f"{type(module)} does not have a method {method_name}")
    orig_method = getattr(module, method_name)

    @functools.wraps(orig_method)
    def wrapped_method(self, *args, **kwargs):
        fsdp_state = self._get_fsdp_state()
        args, kwargs = fsdp_state._pre_forward(self, args, kwargs)
        out = orig_method(*args, **kwargs)
        return fsdp_state._post_forward(self, args, out)

    # Use `__get__` to make `wrapped_method` an instance method
    setattr(
        module,
        method_name,
        wrapped_method.__get__(module, type(module)),  # type:ignore[attr-defined]
    )


def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None:
    for module in modules:
        if not isinstance(module, FSDPModule):
            raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}")
@@ -24,7 +24,7 @@ from typing import (

 import torch
 import torch.distributed as dist
-from torch.distributed._composable.fsdp.fully_shard import FSDPModule, UnshardHandle
+from torch.distributed.fsdp import FSDPModule, UnshardHandle
 from torch.profiler import record_function

 from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
@@ -10,7 +10,7 @@ import torch.distributed as dist
 import torch.fx as fx
 import torch.nn as nn
 from torch._subclasses.fake_tensor import FakeTensor
-from torch.distributed._composable.fsdp.fully_shard import FSDPModule, fully_shard
+from torch.distributed.fsdp import FSDPModule, fully_shard
 from torch.fx.node import map_aggregate
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils._pytree import tree_map_only
@@ -31,14 +31,17 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint
-from torch.distributed._composable.fsdp import fully_shard
-from torch.distributed._composable.fsdp._fsdp_param_group import (
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import (
+    CPUOffload,
+    fully_shard,
+    FullyShardedDataParallel as FSDP,
+)
+from torch.distributed.fsdp._common_utils import TrainingState
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
     FSDPParamGroup,
     RegisterPostBackwardFunction,
 )
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp._common_utils import TrainingState
 from torch.distributed.fsdp._init_utils import NO_RESHARD_AFTER_FORWARD_STRATEGIES
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     BackwardPrefetch,
@@ -1484,7 +1487,7 @@ class FSDPTest(MultiProcessTestCase):

 def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
     def fully_shard_with_compiled_compute(*args, **kwargs):
-        torch.distributed._composable.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
+        torch.distributed.fsdp.fully_shard(*args, **kwargs)  # type: ignore[operator]
         if compile_compute_on_module is None or isinstance(
             args[0], compile_compute_on_module
         ):
@@ -1497,7 +1500,7 @@ def test_compiled_fsdp(compile_compute_on_module: Optional[type] = None):
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            original_fully_shard = torch.distributed._composable.fsdp.fully_shard
+            original_fully_shard = torch.distributed.fsdp.fully_shard
             for mode in FullyShardMode:
                 if mode != FullyShardMode.EAGER and not has_triton():
                     warnings.warn("Inductor on GPU needs Triton and recent GPU arch")