mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
[reland] Create torch.distributed._shard package. (#72141)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141

We have many sharding components currently: torch.distributed._sharded_tensor, torch.distributed._sharding_spec, torch.distributed._sharded_optim, and more coming. As a result, we are organizing all of them under the `torch.distributed._shard` package. For BC reasons, the old packages are kept and simply reference the new package.

ghstack-source-id: 148150861

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33904585

fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52)
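In practice the change is an import-path rename. The hedged sketch below (not part of the commit) shows the mapping, assuming the BC shim modules (shown for `_sharded_tensor` later in this diff) are in place for all three old packages:

    # New canonical locations under torch.distributed._shard:
    from torch.distributed._shard import sharded_tensor, shard_parameter
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec
    from torch.distributed._shard.sharded_optim import ShardedOptimizer

    # Old locations keep importing for BC but emit a DeprecationWarning:
    from torch.distributed import _sharded_tensor   # -> _shard.sharded_tensor
    from torch.distributed import _sharding_spec    # -> _shard.sharding_spec
    from torch.distributed import _sharded_optim    # -> _shard.sharded_optim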
This commit is contained in: parent 7b014cc645, commit 64670e414e
33 changed files with 927 additions and 882 deletions
@@ -2,13 +2,13 @@
 import torch
 import torch.optim as optim
-import torch.distributed._sharded_tensor as sharded_tensor
+import torch.distributed._shard.sharded_tensor
 
 from copy import deepcopy
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
-from torch.distributed._sharded_optim import (
+from torch.distributed._shard.sharded_optim import (
     ShardedOptimizer,
     named_params_with_sharded_tensor
 )
@@ -20,7 +20,7 @@ from torch.testing._internal.common_utils import (
     run_tests,
 )
 
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
@@ -4,17 +4,17 @@ import sys
 import torch
 import torch.distributed as dist
 
-from torch.distributed import _sharded_tensor
+from torch.distributed._shard import sharded_tensor
 from torch.distributed.distributed_c10d import _get_default_group
 
 from torch.testing._internal.common_distributed import (
     requires_nccl,
     skip_if_lt_x_gpu,
 )
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
@@ -35,9 +35,9 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         pg1 = _get_default_group() if pg1 is None else pg1
         pg2 = _get_default_group() if pg2 is None else pg2
         torch.manual_seed(TestShardedTensorBinaryOps.seed)
-        st1 = _sharded_tensor.rand(spec1, sizes, process_group=pg1)
+        st1 = sharded_tensor.rand(spec1, sizes, process_group=pg1)
         torch.manual_seed(TestShardedTensorBinaryOps.seed + seed_offset)
-        st2 = _sharded_tensor.rand(spec2, sizes, process_group=pg2)
+        st2 = sharded_tensor.rand(spec2, sizes, process_group=pg2)
 
         TestShardedTensorBinaryOps.seed += 1
         return st1, st2
@@ -72,23 +72,23 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         torch.nn.init.uniform_(st1.local_shards()[0].tensor)
         self.assertFalse(cmp_op(st1, st2))
 
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 5)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 5)
         self.assertFalse(cmp_op(st1, st2))
 
         st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
         self.assertFalse(cmp_op(st1, st2))
 
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.zeros(spec, 10, 10)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.zeros(spec, 10, 10)
         self.assertFalse(cmp_op(st1, st2))
 
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
         self.assertFalse(cmp_op(st1, st2))
 
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 10, requires_grad=True)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
         self.assertFalse(cmp_op(st1, st2))
 
         cpu_spec = ChunkShardingSpec(
@@ -100,8 +100,8 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
                 "rank:3/cpu",
             ],
         )
-        st1 = _sharded_tensor.ones(cpu_spec, 10, 10)
-        st2 = _sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
+        st1 = sharded_tensor.ones(cpu_spec, 10, 10)
+        st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
         self.assertFalse(cmp_op(st1, st2))
 
         pg = dist.new_group([1, 0, 3, 2])
@@ -149,7 +149,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         # compare different arrays
         st1, st2 = self.get_random_tensors(spec, spec, 10, 10, seed_offset=1)
         self.assertFalse(torch.allclose(st1, st2))
-        # _sharded_tensor.rand produces uniform values in the [0,1] range.
+        # sharded_tensor.rand produces uniform values in the [0,1] range.
         self.assertTrue(torch.allclose(st1, st2, atol=1))
 
 if __name__ == '__main__':
@@ -4,7 +4,7 @@ import sys
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import (
     shard_parameter,
 )
 from torch.testing._internal.common_distributed import (
@@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )
@@ -4,7 +4,7 @@ import sys
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import (
     shard_parameter,
 )
 from torch.testing._internal.common_distributed import (
@@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )
@@ -3,15 +3,15 @@
 import sys
 import torch
 
-from torch.distributed import _sharded_tensor
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard import sharded_tensor
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
 from torch.testing._internal.common_distributed import (
     requires_nccl,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
@@ -50,17 +50,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
 
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
 
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)
+        torch.nn.init.uniform_(st, a=a, b=b)
 
         torch.manual_seed(seed)
         torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -85,17 +85,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
 
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
 
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.normal_(sharded_tensor, mean=mean, std=std)
+        torch.nn.init.normal_(st, mean=mean, std=std)
 
         torch.manual_seed(seed)
         torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
 
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -120,17 +120,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
 
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
 
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.kaiming_uniform_(sharded_tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        torch.nn.init.kaiming_uniform_(st, a=a, mode=mode, nonlinearity=nonlinearity)
 
         torch.manual_seed(seed)
         torch.nn.init.kaiming_uniform_(local_tensor_clone, a=a, mode=mode, nonlinearity=nonlinearity)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
 
 if __name__ == '__main__':
     run_tests()
@@ -4,16 +4,16 @@ import sys
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import shard_parameter
+from torch.distributed._shard.sharded_tensor import (
     empty,
-    shard_parameter,
 )
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     EnumerableShardingSpec,
     ShardMetadata
 )
-from torch.distributed._sharded_optim import (
+from torch.distributed._shard.sharded_optim import (
     ShardedOptimizer,
     named_params_with_sharded_tensor,
 )
@@ -25,12 +25,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )
[Diff for one file suppressed because it is too large.]
@@ -2,13 +2,13 @@
 
 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     DevicePlacementSpec,
     EnumerableShardingSpec,
     ShardMetadata,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     get_split_size,
     get_chunked_dim_size,
@@ -202,13 +202,14 @@ WINDOWS_BLOCKLIST = [
     "distributed/pipeline/sync/test_worker",
     "distributed/elastic/agent/server/test/api_test",
     "distributed/elastic/multiprocessing/api_test",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
 ] + FSDP_TEST + FX2TRT_TESTS
 
 ROCM_BLOCKLIST = [
@@ -216,13 +217,13 @@ ROCM_BLOCKLIST = [
     "distributed/rpc/test_faulty_agent",
     "distributed/rpc/test_tensorpipe_agent",
     "distributed/rpc/cuda/test_tensorpipe_agent",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
     "test_determination",
     "test_multiprocessing",
     "test_jit_legacy",
@@ -354,14 +355,14 @@ DISTRIBUTED_TESTS = [
     "distributed/elastic/utils/util_test",
     "distributed/elastic/utils/distributed_test",
     "distributed/elastic/multiprocessing/api_test",
-    "distributed/_sharding_spec/test_sharding_spec",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
 ] + [test for test in TESTS if test.startswith("distributed/fsdp")]
 
 # Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when
torch/distributed/_shard/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .api import shard_parameter
torch/distributed/_shard/api.py (new file, 145 lines)
@@ -0,0 +1,145 @@
import copy
import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from .sharding_spec import (
    ChunkShardingSpec,
    ShardingSpec,
)
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardMetadata,
    ShardedTensor,
)

def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
    module, it shards that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingspec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
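For orientation, a hypothetical usage sketch of the `shard_parameter` API added above (not part of the diff); it assumes `torch.distributed` is already initialized with a four-rank GPU process group:

    # Hypothetical sketch: shard a Linear weight row-wise across four ranks.
    import torch
    import torch.distributed as dist
    from torch.distributed._shard import shard_parameter
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    linear = torch.nn.Linear(32, 16).cuda(dist.get_rank())
    shard_parameter(linear, "weight", spec)
    # linear.weight is now a ShardedTensor; each rank holds one 4x32 local shard.
    assert len(linear.weight.local_shards()) == 1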
@@ -3,7 +3,7 @@ from .api import ShardedOptimizer
 
 import torch.nn as nn
 
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     ShardedTensor
 )
 
@@ -16,7 +16,7 @@ def named_params_with_sharded_tensor(
    r"""Returns an iterator over module parameters (together with the
    ShardedTensor parameters), yielding both the name of the parameter
    as well as the parameter itself. This is typically passed to a
-   :class:torch.distributed._sharded_optim.ShardedOptimizer
+   :class:torch.distributed._shard.sharded_optim.ShardedOptimizer
 
    Args:
        prefix (str): prefix to prepend to all parameter names.
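A hypothetical end-to-end sketch (not part of the diff) of how `named_params_with_sharded_tensor` feeds a `ShardedOptimizer`, mirroring how the tests in this commit use it; `model` is assumed to be a module whose parameters were sharded beforehand via `shard_parameter(...)`:

    # Hypothetical sketch: optimize a module containing ShardedTensor params.
    import torch.optim as optim
    from torch.distributed._shard.sharded_optim import (
        ShardedOptimizer,
        named_params_with_sharded_tensor,
    )

    sharded_optim = ShardedOptimizer(
        dict(named_params_with_sharded_tensor(model)),
        optim.SGD,  # the regular local optimizer class to delegate to
        lr=0.1,
    )
    sharded_optim.zero_grad()
    # ... forward pass, loss computation, backward pass ...
    sharded_optim.step()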
@@ -2,7 +2,7 @@ from typing import List, Union, Mapping, Dict, Any
 
 import torch.optim as optim
 from torch import Tensor
-from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor import ShardedTensor
 
 
 class ShardedOptimizer(optim.Optimizer):
torch/distributed/_shard/sharded_tensor/__init__.py (new file, 412 lines)
@@ -0,0 +1,412 @@
# coding=utf-8

import copy
import functools
import torch
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    ShardingSpec,
)
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)
from typing import List

from .api import (
    _register_sharded_op,
    CreateOp,
    Shard,
    ShardMetadata,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorInitParams,
    TensorProperties,
)
from .utils import load_with_process_group
import torch.distributed as dist
from torch.distributed import distributed_c10d


def empty(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
          requires_grad=False,
          pin_memory=False,
          memory_format=torch.contiguous_format,
          process_group=None,
          init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with uninitialized data.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_properties = TensorProperties(dtype=dtype, layout=layout,
                                         requires_grad=requires_grad,
                                         pin_memory=pin_memory, memory_format=memory_format, )
    tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, )
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )

def ones(sharding_spec: ShardingSpec,
         *size,
         dtype=None,
         layout=torch.strided,
         requires_grad=False,
         pin_memory=False,
         memory_format=torch.contiguous_format,
         process_group=None,
         init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` with the scalar value 1.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_properties = TensorProperties(dtype=dtype, layout=layout,
                                         requires_grad=requires_grad,
                                         pin_memory=pin_memory, memory_format=memory_format, )
    tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties)
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def rand(sharding_spec: ShardingSpec,
         *size,
         dtype=None,
         layout=torch.strided,
         requires_grad=False,
         pin_memory=False,
         memory_format=torch.contiguous_format,
         process_group=None,
         init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the
    interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format
    )
    tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def zeros(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
          requires_grad=False,
          pin_memory=False,
          memory_format=torch.contiguous_format,
          process_group=None,
          init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 0.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format,
    )
    tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def full(sharding_spec: ShardingSpec,
         size,
         fill_value=torch.types.Number,
         dtype=None,
         layout=torch.strided,
         requires_grad=False,
         pin_memory=False,
         memory_format=torch.contiguous_format,
         process_group=None,
         init_rrefs=False) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype
    is inferred from fill_value. If dtype is specified, it will override the
    inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.
        fill_value (Scalar) – the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format,
    )
    tensor_init_params = TensorInitParams(
        create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )


def init_from_local_shards(
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False) -> ShardedTensor:
    """
    Creates an :class:`ShardedTensor` from local shards and the global metadata.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list
            of shards that represent the local shards on this rank.
        global_size (int...): a list, tuple, or `torch.Size` of integers defining the
            shape of the overall sharded tensor.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object handle on this rank


    Examples:
        Suppose we want construct a sharded tensor on two ranks, global size = (10, 5),
        each shard have a (5, 5) local tensor, we can do it like below:

        on rank 0:
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[0, 0]
        >>>     shard_lengths=[5, 5]
        >>>     placement="rank:0/cuda:0"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])

        on rank 1:
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[5, 0]
        >>>     shard_lengths=[5, 5]
        >>>     placement="rank:1/cuda:1"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
    """
    return ShardedTensor._init_from_local_shards(
        local_shards,
        *global_size,
        process_group=process_group,
        init_rrefs=init_rrefs
    )

def state_dict_hook(module, destination, prefix, local_metadata):
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using
    :meth:`torch.nn.Module._register_state_dict_hook`.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            if isinstance(attr, ShardedTensor):
                destination[prefix + submodule_name + '.' + attr_name] = attr

def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    Pre-load state dict hook to add ShardedTensor to the module.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            key = prefix + submodule_name + '.' + attr_name
            if key in state_dict:
                if isinstance(state_dict[key], ShardedTensor):
                    setattr(submodule, attr_name, state_dict[key])

def sharded_op_impl(func):
    """
    Provides a way for users to write their own custom sharded operator. This
    can be used to override existing ShardedTensor operators or write a new
    one not supported by ShardedTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> @custom_sharded_op(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ....
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = sharded_tensor.rand(32, 16)
        >>> bias = torch.rand(16)
        >>> # This will call 'my_custom_sharded_linear'
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
    There is an additional ``process_group`` parameter which is the
    process_group used for the ShardedTensor and can be used by
    implementations for communications within a sharded implementation.

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """
    def decorator_sharded_func(wrapped_func):
        _register_sharded_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)
        return wrapper
    return decorator_sharded_func

# Import all builtin sharded ops
from ._ops import *  # noqa: F403
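A hypothetical usage sketch (not part of the diff) of the creation ops defined above; it assumes a four-rank CPU/gloo process group is already initialized and runs the same code on every rank (SPMD):

    # Hypothetical sketch: create a sharded tensor with sharded_tensor.ones().
    import torch.distributed._shard.sharded_tensor as sharded_tensor
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cpu", "rank:1/cpu", "rank:2/cpu", "rank:3/cpu"],
    )
    st = sharded_tensor.ones(spec, 10, 10)
    # Each rank holds one local shard of rows; the shard's tensor is a plain
    # torch.Tensor that local ops (e.g. torch.nn.init.*) can mutate in place.
    local = st.local_shards()[0].tensor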
@@ -4,7 +4,7 @@ from typing import List
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
 )
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as distributed_c10d
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
     sharded_op_impl
 )
@@ -4,14 +4,14 @@ from typing import cast
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _communicate_size_to_each_rank,
     _handle_col_wise_sharding_base,
     _handle_row_wise_lookup_distribute,
     _handle_max_norm_col_wise,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )
@@ -7,15 +7,15 @@ import torch.distributed as dist
 from torch._C._distributed_c10d import (
     ReduceOp,
 )
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _communicate_list_to_each_rank,
     _communicate_size_to_each_rank,
     _handle_col_wise_sharding_base,
     _handle_row_wise_lookup_distribute,
     _handle_max_norm_col_wise,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )
@@ -1,5 +1,5 @@
 import torch
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
 )
 
@@ -2,11 +2,11 @@ from typing import List, cast
 
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _handle_col_wise_sharding_base,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
 )
@@ -15,7 +15,7 @@ from torch.distributed.nn.functional import (
     reduce_scatter,
 )
 
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )
@@ -14,13 +14,13 @@ import torch
 import torch.distributed as dist
 from torch.distributed import rpc
 from torch.distributed import distributed_c10d
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     EnumerableShardingSpec,
     ShardMetadata,
     ShardingSpec,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     get_split_size,
     get_chunked_dim_size,
@@ -108,13 +108,13 @@ class ShardedTensor(object):
 
     ShardedTensor doesn't provide any Tensor like operations but is a wrapper
     providing the Tensor representing the local shard and the global metadata.
-    Using these, users can build their custom distributed sharded computations
+    Using these, users can build their custom distributed._sharded computations
     on top of this primitive. The local shards are all initialized using the
     create_op specified by tensor_init_params.create_op, e.g., torch.ones, or
     torch.empty
 
     Args:
-        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.
@@ -138,7 +138,7 @@ class ShardedTensor(object):
 
     def __new__(cls, *args, **kwargs):
         # Use __new__ for logging purposes.
-        torch._C._log_api_usage_once("torch.distributed.sharded_tensor")
+        torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
         return super(ShardedTensor, cls).__new__(cls)
 
     def __init__(
@@ -3,7 +3,7 @@ from enum import Enum
 from typing import List
 
 import torch
-from torch.distributed._sharding_spec import ShardMetadata
+from torch.distributed._shard.sharding_spec import ShardMetadata
 
 
 class MEM_FORMAT_ENCODING(Enum):
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import List, cast
 
 import torch
-from torch.distributed._sharding_spec import ShardMetadata
+from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed.remote_device import _remote_device
 
 
@@ -14,7 +14,7 @@ class Shard(object):
 
     Args:
         tensor(torch.Tensor): Local tensor for the shard.
-        metadata(:class `torch.distributed._sharded_tensor.ShardMetadata`):
+        metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`):
            The metadata for the shard, including offsets, lengths and device placement.
    """
    __slots__ = ['tensor', 'metadata']
@@ -5,10 +5,10 @@ from typing import Optional, List, Sequence
 import torch
 from torch.distributed import distributed_c10d
 from torch.distributed import rpc
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ShardMetadata,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     validate_non_overlapping_shards_metadata,
 )
torch/distributed/_shard/sharding_spec/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from .api import (
    ChunkShardingSpec,
    DevicePlacementSpec,
    EnumerableShardingSpec,
    PlacementSpec,
    ShardMetadata,
    ShardingSpec,
)
|
|
@ -1,540 +1,12 @@
|
|||
# coding=utf-8
|
||||
|
||||
import copy
|
||||
import functools
|
||||
# Keep old package for BC purposes, this file should be removed once
|
||||
# everything moves to the `torch.distributed._shard` package.
|
||||
import sys
|
||||
import torch
|
||||
from torch.distributed._sharding_spec import (
|
||||
ChunkShardingSpec,
|
||||
ShardingSpec,
|
||||
import warnings
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import * # noqa: F403
|
||||
warnings.warn(
|
||||
"torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead",
|
||||
DeprecationWarning
|
||||
)
|
||||
from torch.distributed._sharding_spec._internals import (
|
||||
get_chunked_dim_size,
|
||||
get_split_size,
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from .api import (
|
||||
_register_sharded_op,
|
||||
CreateOp,
|
||||
Shard,
|
||||
ShardMetadata,
|
||||
ShardedTensor,
|
||||
ShardedTensorMetadata,
|
||||
TensorInitParams,
|
||||
TensorProperties,
|
||||
)
|
||||
from .utils import load_with_process_group
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import distributed_c10d
|
||||
|
||||
|
||||
def empty(sharding_spec: ShardingSpec,
|
||||
*size,
|
||||
dtype=None,
|
||||
layout=torch.strided,
|
||||
requires_grad=False,
|
||||
pin_memory=False,
|
||||
memory_format=torch.contiguous_format,
|
||||
process_group=None,
|
||||
init_rrefs=False) -> ShardedTensor:
|
||||
"""
|
||||
Returns a :class:`ShardedTensor` filled with uninitialized data.
|
||||
Needs to be called on all ranks in an SPMD fashion.
|
||||
|
||||
Args:
|
||||
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
|
||||
describing how to shard the Tensor.
|
||||
size (int...): a sequence of integers defining the shape of the output
|
||||
tensor. Can be a variable number of arguments or a collection like a list or tuple.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
|
||||
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
|
||||
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
|
||||
Default: ``torch.strided``.
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned tensor. Default: ``False``.
|
||||
pin_memory (bool, optional): If set, returned tensor would be allocated in
|
||||
the pinned memory. Works only for CPU tensors. Default: ``False``.
|
||||
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
|
||||
returned Tensor. Default: ``torch.contiguous_format``.
|
||||
process_group (ProcessGroup, optional): The process group to work on. If None,
|
||||
the default process group will be used.
|
||||
init_rrefs (bool, optional): Whether or not to initialize
|
||||
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
|
||||
Need to initialize the RPC Framework if specified as ``True``.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object on each rank
|
||||
"""
|
||||
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
|
||||
requires_grad=requires_grad,
|
||||
pin_memory=pin_memory, memory_format=memory_format, )
|
||||
tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, )
|
||||
return ShardedTensor(
|
||||
sharding_spec,
|
||||
*size,
|
||||
tensor_init_params=tensor_init_params,
|
||||
process_group=process_group,
|
||||
init_rrefs=init_rrefs,
|
||||
)
|
||||
|
||||
def ones(sharding_spec: ShardingSpec,
|
||||
*size,
|
||||
dtype=None,
|
||||
layout=torch.strided,
|
||||
requires_grad=False,
|
||||
pin_memory=False,
|
||||
memory_format=torch.contiguous_format,
|
||||
process_group=None,
|
||||
init_rrefs=False) -> ShardedTensor:
|
||||
"""
|
||||
Returns a :class:`ShardedTensor` with the scalar value 1.
|
||||
Needs to be called on all ranks in an SPMD fashion.
|
||||
|
||||
Args:
|
||||
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
|
||||
describing how to shard the Tensor.
|
||||
size (int...): a sequence of integers defining the shape of the output
|
||||
tensor. Can be a variable number of arguments or a collection like a list or tuple.
|
||||
|
||||
Keyword args:
|
||||
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
|
||||
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
|
||||
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
|
||||
Default: ``torch.strided``.
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned tensor. Default: ``False``.
|
||||
pin_memory (bool, optional): If set, returned tensor would be allocated in
|
||||
the pinned memory. Works only for CPU tensors. Default: ``False``.
|
||||
process_group (ProcessGroup, optional): The process group to work on. If None,
|
||||
the default process group will be used.
|
||||
init_rrefs (bool, optional): Whether or not to initialize
|
||||
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
|
||||
Need to initialize the RPC Framework if specified as ``True``.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
A :class:`ShardedTensor` object on each rank
|
||||
"""
|
||||
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
|
||||
requires_grad=requires_grad,
|
||||
pin_memory=pin_memory, memory_format=memory_format, )
|
||||
tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties)
|
||||
return ShardedTensor(
|
||||
sharding_spec,
|
||||
*size,
|
||||
tensor_init_params=tensor_init_params,
|
||||
process_group=process_group,
|
||||
init_rrefs=init_rrefs,
|
||||
)
|
||||
|
||||
|
||||
def rand(sharding_spec: ShardingSpec,
         *size,
         dtype=None,
         layout=torch.strided,
         requires_grad=False,
         pin_memory=False,
         memory_format=torch.contiguous_format,
         process_group=None,
         init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution
    on the interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of the returned tensor.
            Default: if ``None``, uses the global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of the returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, the returned tensor is allocated in
            pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Requires the RPC framework to be initialized if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank.
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format,
    )
    tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties)
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )

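# Usage sketch for ``rand`` (illustrative): each rank fills only its own local
# shards, so seeding identically on every rank, as done here, is an assumption
# used to keep runs reproducible.
#
# >>> torch.manual_seed(0)
# >>> st = rand(spec, 10, 10)  # ``spec`` as in the ``ones`` sketch above
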
def zeros(sharding_spec: ShardingSpec,
          *size,
          dtype=None,
          layout=torch.strided,
          requires_grad=False,
          pin_memory=False,
          memory_format=torch.contiguous_format,
          process_group=None,
          init_rrefs=False) -> ShardedTensor:
    """
    Returns a :class:`ShardedTensor` filled with the scalar value 0.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of the returned tensor.
            Default: if ``None``, uses the global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of the returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, the returned tensor is allocated in
            pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Requires the RPC framework to be initialized if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank.
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format,
    )
    tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties)
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )

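# Usage sketch for ``zeros`` (illustrative): the keyword args mirror the
# dense ``torch.zeros`` factory.
#
# >>> st = zeros(spec, 10, 10, dtype=torch.float16, requires_grad=True)
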
def full(sharding_spec: ShardingSpec,
         size,
         fill_value: torch.types.Number,
         dtype=None,
         layout=torch.strided,
         requires_grad=False,
         pin_memory=False,
         memory_format=torch.contiguous_format,
         process_group=None,
         init_rrefs=False) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` filled with ``fill_value``. The tensor's dtype
    is inferred from ``fill_value``. If ``dtype`` is specified, it overrides the
    type inferred from ``fill_value``. Needs to be called on all ranks in an SPMD fashion.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.
        size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
            output tensor.
        fill_value (Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of the returned tensor.
            Default: if ``None``, uses the global default (see :func:`torch.set_default_tensor_type`).
        layout (:class:`torch.layout`, optional): the desired layout of the returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, the returned tensor is allocated in
            pinned memory. Works only for CPU tensors. Default: ``False``.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Requires the RPC framework to be initialized if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object on each rank.
    """
    tensor_properties = TensorProperties(
        dtype=dtype, layout=layout, requires_grad=requires_grad,
        pin_memory=pin_memory, memory_format=memory_format,
    )
    tensor_init_params = TensorInitParams(
        create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
    return ShardedTensor(
        sharding_spec,
        *size,
        tensor_init_params=tensor_init_params,
        process_group=process_group,
        init_rrefs=init_rrefs,
    )

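# Usage sketch for ``full`` (illustrative): note that ``size`` is a single
# collection here, unlike the ``*size`` varargs of the factories above.
#
# >>> st = full(spec, (10, 10), 3.0)  # dtype inferred from fill_value
# >>> st = full(spec, (10, 10), 3, dtype=torch.float64)  # explicit dtype wins
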
def init_from_local_shards(
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False) -> ShardedTensor:
    """
    Creates a :class:`ShardedTensor` from local shards and the global metadata.
    Needs to be called on all ranks in an SPMD fashion.

    Args:
        local_shards (List[:class:`torch.distributed._shard.sharded_tensor.Shard`]): A list
            of shards that represent the local shards on this rank.
        global_size (int...): a list, tuple, or `torch.Size` of integers defining the
            shape of the overall sharded tensor.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Requires the RPC framework to be initialized if specified as ``True``.
            Default: ``False``.

    Returns:
        A :class:`ShardedTensor` object handle on this rank.

    Examples:
        Suppose we want to construct a sharded tensor on two ranks, with global
        size = (10, 5) and a (5, 5) local tensor on each shard; we can do it as below:

        on rank 0:
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[0, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:0/cuda:0"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])

        on rank 1:
        >>> local_shard_metadata = ShardMetadata(
        >>>     shard_offsets=[5, 0],
        >>>     shard_sizes=[5, 5],
        >>>     placement="rank:1/cuda:1"
        >>> )
        >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
        >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
    """
    return ShardedTensor._init_from_local_shards(
        local_shards,
        *global_size,
        process_group=process_group,
        init_rrefs=init_rrefs
    )

def state_dict_hook(module, destination, prefix, local_metadata):
    """
    Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
    registered to the Module using
    :meth:`torch.nn.Module._register_state_dict_hook`.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            if isinstance(attr, ShardedTensor):
                destination[prefix + submodule_name + '.' + attr_name] = attr


def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
    """
    Pre-load state dict hook that restores ShardedTensor attributes on the
    module from the incoming ``state_dict``.
    """
    for submodule_name, submodule in module.named_modules():
        for attr_name, attr in submodule.__dict__.items():
            key = prefix + submodule_name + '.' + attr_name
            if key in state_dict and isinstance(state_dict[key], ShardedTensor):
                setattr(submodule, attr_name, state_dict[key])

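# Usage sketch for the two hooks above (illustrative): ``MyShardedModel`` is a
# hypothetical module holding ShardedTensor attributes, and binding ``module``
# via ``functools.partial`` for the pre-load registration is an assumption.
#
# >>> m = MyShardedModel()
# >>> m._register_state_dict_hook(state_dict_hook)
# >>> m._register_load_state_dict_pre_hook(functools.partial(pre_load_state_dict_hook, m))
# >>> sd = m.state_dict()  # now includes the ShardedTensor attributes
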
def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank whose data is used as the ground truth
    and scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter;
    # this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Narrow down to the shard for this rank. We don't want autograd
    # recording the narrow op here, and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create a ShardedTensor based on the local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)

    # Manually set the sharding_spec.
    st._sharding_spec = sharding_spec

    # Replace the param with the ShardedTensor. Need to delete the attribute
    # first, since param_name might be a torch.nn.Parameter and can't be
    # replaced with a ShardedTensor, which is not a torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)

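# Usage sketch for ``shard_parameter`` (illustrative): shard a Linear layer's
# weight along dim 0 across two GPUs. Assumes ranks 0 and 1 each run this
# same code with an initialized default process group.
#
# >>> fc = torch.nn.Linear(16, 16).cuda()
# >>> spec = ChunkShardingSpec(
# >>>     dim=0,
# >>>     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
# >>> )
# >>> shard_parameter(fc, 'weight', spec)  # fc.weight is now a ShardedTensor
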
def sharded_op_impl(func):
    """
    Provides a way for users to write their own custom sharded operator. This
    can be used to override existing ShardedTensor operators or write a new
    one not supported by ShardedTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example::
        >>> @sharded_op_impl(torch.nn.functional.linear)
        >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
        >>>     ...
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = sharded_tensor.rand(sharding_spec, 32, 16)
        >>> bias = torch.rand(16)
        >>> # This will call 'my_custom_sharded_linear'
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to the ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
    There is an additional ``process_group`` parameter, which is the
    process_group used for the ShardedTensor and can be used by
    implementations for communications within a sharded implementation.

    Args:
        func (Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """
    def decorator_sharded_func(wrapped_func):
        _register_sharded_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)
        return wrapper
    return decorator_sharded_func

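# A concrete registration sketch (an assumption, not a builtin override):
# apply an elementwise op shard-by-shard by registering a handler for it.
#
# >>> @sharded_op_impl(torch.nn.functional.relu)
# >>> def sharded_relu(types, args, kwargs, process_group):
# >>>     st = args[0]
# >>>     for shard in st.local_shards():  # mutate each local shard in place
# >>>         torch.nn.functional.relu(shard.tensor, inplace=True)
# >>>     return st
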
# Import all builtin sharded ops
from .ops import *  # noqa: F403

sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor
@@ -1,8 +1,12 @@
from .api import (
    ChunkShardingSpec,
    DevicePlacementSpec,
    EnumerableShardingSpec,
    PlacementSpec,
    ShardMetadata,
    ShardingSpec,
# Keep the old package around for BC purposes; this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
import warnings

from torch.distributed._shard.sharding_spec import *  # noqa: F403
warnings.warn(
    "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead",
    DeprecationWarning
)
sys.modules['torch.distributed._sharding_spec'] = torch.distributed._shard.sharding_spec
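# Behavior sketch of the shim above (illustrative): thanks to the
# ``sys.modules`` alias, the old and new import paths yield the same module
# object, and the first import under the old name emits the DeprecationWarning.
#
# >>> import torch.distributed._sharding_spec as old_spec
# >>> import torch.distributed._shard.sharding_spec as new_spec
# >>> old_spec is new_spec
# True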
0  torch/testing/_internal/distributed/_shard/__init__.py  Normal file
@@ -1,7 +1,7 @@
from torch.distributed._sharding_spec import (
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._sharding_spec._internals import (
from torch.distributed._shard.sharding_spec._internals import (
    get_chunked_dim_size,
    get_split_size,
)