[reland] Create torch.distributed._shard package. (#72141)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141

We currently have many sharding components: torch.distributed._sharded_tensor,
torch.distributed._sharding_spec, torch.distributed._sharded_optimizer, and
more are coming.

As a result, this change organizes all of them under the
`torch.distributed._shard` package. For BC reasons, I'm keeping the old
packages around and having them simply reference the new package.
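
To illustrate the aliasing pattern (a minimal sketch distilled from the
`_sharding_spec` BC shim in this diff; the other legacy packages follow the
same shape):

    # torch/distributed/_sharding_spec/__init__.py (BC shim)
    import sys
    import torch
    import warnings
    from torch.distributed._shard.sharding_spec import *  # noqa: F403

    warnings.warn(
        "torch.distributed._sharding_spec will be deprecated, "
        "use torch.distributed._shard.sharding_spec instead",
        DeprecationWarning
    )
    # Alias the old module name to the new package so that existing
    # `import torch.distributed._sharding_spec` call sites keep working.
    sys.modules['torch.distributed._sharding_spec'] = torch.distributed._shard.sharding_spec
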
ghstack-source-id: 148150861

Test Plan: waitforbuildbot

Reviewed By: fduwjj

Differential Revision: D33904585

fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062
(cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52)
Authored by Pritam Damania on 2022-02-01 22:53:18 -08:00; committed by PyTorch MergeBot.
parent 7b014cc645
commit 64670e414e
33 changed files with 927 additions and 882 deletions


@@ -2,13 +2,13 @@
 import torch
 import torch.optim as optim
-import torch.distributed._sharded_tensor as sharded_tensor
+import torch.distributed._shard.sharded_tensor
 from copy import deepcopy
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
-from torch.distributed._sharded_optim import (
+from torch.distributed._shard.sharded_optim import (
     ShardedOptimizer,
     named_params_with_sharded_tensor
 )
@@ -20,7 +20,7 @@ from torch.testing._internal.common_utils import (
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )


@@ -4,17 +4,17 @@ import sys
 import torch
 import torch.distributed as dist
-from torch.distributed import _sharded_tensor
+from torch.distributed._shard import sharded_tensor
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.testing._internal.common_distributed import (
     requires_nccl,
     skip_if_lt_x_gpu,
 )
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
@@ -35,9 +35,9 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         pg1 = _get_default_group() if pg1 is None else pg1
         pg2 = _get_default_group() if pg2 is None else pg2
         torch.manual_seed(TestShardedTensorBinaryOps.seed)
-        st1 = _sharded_tensor.rand(spec1, sizes, process_group=pg1)
+        st1 = sharded_tensor.rand(spec1, sizes, process_group=pg1)
         torch.manual_seed(TestShardedTensorBinaryOps.seed + seed_offset)
-        st2 = _sharded_tensor.rand(spec2, sizes, process_group=pg2)
+        st2 = sharded_tensor.rand(spec2, sizes, process_group=pg2)
         TestShardedTensorBinaryOps.seed += 1
         return st1, st2
@@ -72,23 +72,23 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         torch.nn.init.uniform_(st1.local_shards()[0].tensor)
         self.assertFalse(cmp_op(st1, st2))
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 5)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 5)
         self.assertFalse(cmp_op(st1, st2))
         st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
         self.assertFalse(cmp_op(st1, st2))
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.zeros(spec, 10, 10)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.zeros(spec, 10, 10)
         self.assertFalse(cmp_op(st1, st2))
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
         self.assertFalse(cmp_op(st1, st2))
-        st1 = _sharded_tensor.ones(spec, 10, 10)
-        st2 = _sharded_tensor.ones(spec, 10, 10, requires_grad=True)
+        st1 = sharded_tensor.ones(spec, 10, 10)
+        st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
         self.assertFalse(cmp_op(st1, st2))
         cpu_spec = ChunkShardingSpec(
@@ -100,8 +100,8 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
                 "rank:3/cpu",
             ],
         )
-        st1 = _sharded_tensor.ones(cpu_spec, 10, 10)
-        st2 = _sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
+        st1 = sharded_tensor.ones(cpu_spec, 10, 10)
+        st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
         self.assertFalse(cmp_op(st1, st2))
         pg = dist.new_group([1, 0, 3, 2])
@@ -149,7 +149,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase):
         # compare different arrays
         st1, st2 = self.get_random_tensors(spec, spec, 10, 10, seed_offset=1)
         self.assertFalse(torch.allclose(st1, st2))
-        # _sharded_tensor.rand produces uniform values in the [0,1] range.
+        # sharded_tensor.rand produces uniform values in the [0,1] range.
         self.assertTrue(torch.allclose(st1, st2, atol=1))
 if __name__ == '__main__':


@@ -4,7 +4,7 @@ import sys
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import (
     shard_parameter,
 )
 from torch.testing._internal.common_distributed import (
@@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )


@@ -4,7 +4,7 @@ import sys
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import (
     shard_parameter,
 )
 from torch.testing._internal.common_distributed import (
@@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )


@@ -3,15 +3,15 @@
 import sys
 import torch
-from torch.distributed import _sharded_tensor
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard import sharded_tensor
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
 from torch.testing._internal.common_distributed import (
     requires_nccl,
     skip_if_lt_x_gpu,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
@@ -50,17 +50,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)
+        torch.nn.init.uniform_(st, a=a, b=b)
         torch.manual_seed(seed)
         torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -85,17 +85,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.normal_(sharded_tensor, mean=mean, std=std)
+        torch.nn.init.normal_(st, mean=mean, std=std)
         torch.manual_seed(seed)
         torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -120,17 +120,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase):
         seed = 1234
         dtype = torch.double
-        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
-        self.assertEqual(1, len(sharded_tensor.local_shards()))
+        st = sharded_tensor.empty(spec, h, w, dtype=dtype)
+        self.assertEqual(1, len(st.local_shards()))
         # Clone local tensor to ensure torch.nn.init starts from the same input
-        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
+        local_tensor_clone = torch.clone(st.local_shards()[0].tensor)
         torch.manual_seed(seed)
-        torch.nn.init.kaiming_uniform_(sharded_tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+        torch.nn.init.kaiming_uniform_(st, a=a, mode=mode, nonlinearity=nonlinearity)
         torch.manual_seed(seed)
         torch.nn.init.kaiming_uniform_(local_tensor_clone, a=a, mode=mode, nonlinearity=nonlinearity)
-        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
+        self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor)
 if __name__ == '__main__':
     run_tests()


@@ -4,16 +4,16 @@ import sys
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard import shard_parameter
+from torch.distributed._shard.sharded_tensor import (
     empty,
-    shard_parameter,
 )
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     EnumerableShardingSpec,
     ShardMetadata
 )
-from torch.distributed._sharded_optim import (
+from torch.distributed._shard.sharded_optim import (
     ShardedOptimizer,
     named_params_with_sharded_tensor,
 )
@@ -25,12 +25,12 @@ from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
 )
-from torch.testing._internal.distributed._sharded_tensor import (
+from torch.testing._internal.distributed._shard.sharded_tensor import (
     TEST_GPU_NUM,
     ShardedTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed._sharded_tensor._test_ops_common import (
+from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     generate_chunk_sharding_specs_for_test,
     generate_local_weight_sharding_params_for_test,
 )


@@ -2,13 +2,13 @@
 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     DevicePlacementSpec,
     EnumerableShardingSpec,
     ShardMetadata,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     get_split_size,
     get_chunked_dim_size,


@@ -202,13 +202,14 @@ WINDOWS_BLOCKLIST = [
     "distributed/pipeline/sync/test_worker",
     "distributed/elastic/agent/server/test/api_test",
     "distributed/elastic/multiprocessing/api_test",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
 ] + FSDP_TEST + FX2TRT_TESTS
 ROCM_BLOCKLIST = [
@@ -216,13 +217,13 @@ ROCM_BLOCKLIST = [
     "distributed/rpc/test_faulty_agent",
     "distributed/rpc/test_tensorpipe_agent",
     "distributed/rpc/cuda/test_tensorpipe_agent",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
     "test_determination",
     "test_multiprocessing",
     "test_jit_legacy",
@@ -354,14 +355,14 @@ DISTRIBUTED_TESTS = [
     "distributed/elastic/utils/util_test",
     "distributed/elastic/utils/distributed_test",
     "distributed/elastic/multiprocessing/api_test",
-    "distributed/_sharding_spec/test_sharding_spec",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-    "distributed/_sharded_tensor/ops/test_embedding",
-    "distributed/_sharded_tensor/ops/test_embedding_bag",
-    "distributed/_sharded_tensor/ops/test_binary_cmp",
-    "distributed/_sharded_tensor/ops/test_init",
-    "distributed/_sharded_tensor/ops/test_linear",
-    "distributed/_sharded_optim/test_sharded_optim",
+    "distributed/_shard/sharding_spec/test_sharding_spec",
+    "distributed/_shard/sharded_tensor/test_sharded_tensor",
+    "distributed/_shard/sharded_tensor/ops/test_embedding",
+    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
+    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
+    "distributed/_shard/sharded_tensor/ops/test_init",
+    "distributed/_shard/sharded_tensor/ops/test_linear",
+    "distributed/_shard/sharded_optim/test_sharded_optim",
 ] + [test for test in TESTS if test.startswith("distributed/fsdp")]
 # Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when


@@ -0,0 +1 @@
+from .api import shard_parameter


@@ -0,0 +1,145 @@
+import copy
+import torch
+import torch.distributed as dist
+from torch.distributed import distributed_c10d
+from .sharding_spec import (
+    ChunkShardingSpec,
+    ShardingSpec,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+from torch.distributed._shard.sharded_tensor import (
+    Shard,
+    ShardMetadata,
+    ShardedTensor,
+)
+def shard_parameter(
+        module: torch.nn.Module,
+        param_name: str,
+        sharding_spec: ShardingSpec,
+        src_rank=0,
+        process_group=None):
+    """
+    Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
+    module, it shards that parameter according to the provided
+    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
+    used as the ground truth of the data which would be scattered as shards
+    across the rest of the ranks.
+    This method replaces ``module.param_name`` with a
+    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`
+    Args:
+        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
+        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+    Keyword args:
+        src_rank (int, optional): The source rank which is used as the ground truth of
+            the data for the parameter that would be sharded and scattered
+            across the rest of the ranks.
+            Default: 0.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+    .. warning::
+        Only :class:`torch.distributed._shard.sharding_spec.ShardingSpec` is
+        currently supported as the ``sharding_spec``.
+    """
+    # Perform some validation first.
+    if not isinstance(sharding_spec, ChunkShardingSpec):
+        raise ValueError('Only ChunkShardingspec is supported.')
+    if not hasattr(module, param_name):
+        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')
+    tensor = getattr(module, param_name)
+    if not isinstance(tensor, torch.Tensor):
+        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
+    if not tensor.is_contiguous():
+        raise ValueError(f'param: {param_name} is not a contiguous Tensor')
+    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
+    world_size = dist.get_world_size(pg)
+    rank = dist.get_rank(pg)
+    # Validate src_rank and sharding_spec are same across all ranks.
+    gathered_list = [None] * world_size
+    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)
+    for idx, entry in enumerate(gathered_list):
+        if src_rank != entry[0]:  # type: ignore[index]
+            raise ValueError(
+                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
+                f'match with src_rank={entry[0]} on rank: {idx}')
+        if sharding_spec != entry[1]:  # type: ignore[index]
+            raise ValueError(
+                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
+                f'match with sharding_spec={entry[1]} on rank: {idx}')
+    # Rearrange chunks according to placement.
+    local_metadata = None
+    current_offsets = [0] * len(tensor.size())
+    shards_metadata = []
+    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
+    split_size = get_split_size(sharding_dim_size, world_size)
+    tensor_sizes = list(tensor.size())
+    for idx, placement in enumerate(sharding_spec.placements):
+        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+        shard_size = copy.deepcopy(tensor_sizes)
+        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]
+        shard_metadata = ShardMetadata(
+            shard_offsets=copy.deepcopy(current_offsets),
+            shard_sizes=shard_size,
+            placement=placement,
+        )
+        shards_metadata.append(shard_metadata)
+        if rank == placement.rank():  # type: ignore[union-attr]
+            local_metadata = shard_metadata
+        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]
+    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
+    dist.broadcast(tensor, src=src_rank, group=pg)
+    # Reshape to get shard for this rank and we don't want autograd
+    # recording here for the narrow op and 'local_shard' should be a
+    # leaf variable in the autograd graph.
+    local_shard = tensor.narrow(
+        sharding_spec.dim,  # type: ignore[arg-type]
+        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
+        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
+    ).clone().detach().contiguous()
+    # Sync requires_grad to local_shard.
+    local_shard.requires_grad = tensor.requires_grad
+    # Create ShardedTensor based on local shards.
+    local_shards = [
+        Shard(
+            tensor=local_shard,
+            metadata=local_metadata,  # type: ignore[arg-type]
+        )
+    ]
+    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)
+    # Manually set sharding_spec
+    st._sharding_spec = sharding_spec
+    # Replace param with ShardedTensor.
+    # Need to delete the attribute first since param_name might be
+    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
+    # not torch.nn.Parameter.
+    delattr(module, param_name)
+    # Now we can set the attribute appropriately.
+    setattr(module, param_name, st)
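
For context (not part of this commit's diff): a minimal usage sketch of the
new `shard_parameter` API, assuming a 4-rank NCCL process group is already
initialized on every rank (SPMD style); the placements mirror those used in
the tests above.

    import torch
    from torch.distributed._shard import shard_parameter
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # Assumes torch.distributed.init_process_group("nccl", rank=rank,
    # world_size=4) has already run on every rank.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    fc = torch.nn.Linear(16, 16).cuda()
    # Replaces fc.weight with a ShardedTensor chunked along dim 0;
    # rank 0's weight is the ground truth that gets scattered.
    shard_parameter(fc, 'weight', spec)
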


@@ -3,7 +3,7 @@ from .api import ShardedOptimizer
 import torch.nn as nn
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     ShardedTensor
 )
@@ -16,7 +16,7 @@ def named_params_with_sharded_tensor(
     r"""Returns an iterator over module parameters (together with the
     ShardedTensor parameters), yielding both the name of the parameter
     as well as the parameter itself. This is typically passed to a
-    :class:torch.distributed._sharded_optim.ShardedOptimizer
+    :class:torch.distributed._shard.sharded_optim.ShardedOptimizer
     Args:
         prefix (str): prefix to prepend to all parameter names.


@@ -2,7 +2,7 @@ from typing import List, Union, Mapping, Dict, Any
 import torch.optim as optim
 from torch import Tensor
-from torch.distributed._sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor import ShardedTensor
 class ShardedOptimizer(optim.Optimizer):


@@ -0,0 +1,412 @@
# coding=utf-8
import copy
import functools
import torch
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
ShardingSpec,
)
from torch.distributed._shard.sharding_spec._internals import (
get_chunked_dim_size,
get_split_size,
)
from typing import List
from .api import (
_register_sharded_op,
CreateOp,
Shard,
ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
TensorInitParams,
TensorProperties,
)
from .utils import load_with_process_group
import torch.distributed as dist
from torch.distributed import distributed_c10d
def empty(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with uninitialized data.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
returned Tensor. Default: ``torch.contiguous_format``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format, )
tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def ones(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` with the scalar value 1.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format, )
tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties)
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def rand(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the
interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format
)
tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def zeros(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with the scalar value 0.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format,
)
tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def full(sharding_spec: ShardingSpec,
size,
fill_value=torch.types.Number,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype
is inferred from fill_value. If dtype is specified, it will override the
inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
output tensor.
fill_value (Scalar): the value to fill the output tensor with.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format,
)
tensor_init_params = TensorInitParams(
create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def init_from_local_shards(
local_shards: List[Shard],
*global_size,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Creates a :class:`ShardedTensor` from local shards and the global metadata.
Needs to be called on all ranks in an SPMD fashion.
Args:
local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list
of shards that represent the local shards on this rank.
global_size (int...): a list, tuple, or `torch.Size` of integers defining the
shape of the overall sharded tensor.
Keyword args:
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object handle on this rank
Examples:
Suppose we want to construct a sharded tensor on two ranks, with global size = (10, 5)
and a (5, 5) local tensor on each shard. We can do it as below:
on rank 0:
>>> local_shard_metadata = ShardMetadata(
>>>     shard_offsets=[0, 0],
>>>     shard_sizes=[5, 5],
>>>     placement="rank:0/cuda:0"
>>> )
>>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
>>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
on rank 1:
>>> local_shard_metadata = ShardMetadata(
>>>     shard_offsets=[5, 0],
>>>     shard_sizes=[5, 5],
>>>     placement="rank:1/cuda:1"
>>> )
>>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
>>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
"""
return ShardedTensor._init_from_local_shards(
local_shards,
*global_size,
process_group=process_group,
init_rrefs=init_rrefs
)
def state_dict_hook(module, destination, prefix, local_metadata):
"""
Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
registered to the Module using
:meth:`torch.nn.Module._register_state_dict_hook`.
"""
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
if isinstance(attr, ShardedTensor):
destination[prefix + submodule_name + '.' + attr_name] = attr
def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""
Pre-load state dict hook to add ShardedTensor to the module.
"""
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
key = prefix + submodule_name + '.' + attr_name
if key in state_dict:
if isinstance(state_dict[key], ShardedTensor):
setattr(submodule, attr_name, state_dict[key])
def sharded_op_impl(func):
"""
Provides a way for users to write their own custom sharded operator. This
can be used to override existing ShardedTensor operators or write a new
one not supported by ShardedTensor. If the operator in question is covered
by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
parameters, the function provided will be invoked for that operator.
Example::
>>> @sharded_op_impl(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ....
>>>
>>> input = torch.rand(10, 32)
>>> weight = sharded_tensor.rand(32, 16)
>>> bias = torch.rand(16)
>>> # This will call 'my_custom_sharded_linear'
>>> torch.nn.functional.linear(input, weight, bias)
The types, args and kwargs parameters are the same parameters that are
passed to ``__torch_function__`` dispatch API
(https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
There is an additional ``process_group`` parameter which is the
process_group used for the ShardedTensor and can be used by
implementations for communications within a sharded implementation.
Args:
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
def decorator_sharded_func(wrapped_func):
_register_sharded_op(func, wrapped_func)
@functools.wraps(wrapped_func)
def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)
return wrapper
return decorator_sharded_func
# Import all builtin sharded ops
from ._ops import * # noqa: F403
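
Not part of this diff, but for illustration: a short sketch of the creation
APIs defined above, assuming an initialized 2-rank process group (every rank
must make the same calls, SPMD style).

    import torch
    import torch.distributed._shard.sharded_tensor as sharded_tensor
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    st = sharded_tensor.ones(spec, 10, 10)    # each rank holds a (5, 10) shard of ones
    rnd = sharded_tensor.rand(spec, 10, 10)   # shards drawn uniformly from [0, 1)
    local = st.local_shards()[0].tensor       # this rank's local shard
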


@@ -4,7 +4,7 @@ from typing import List
 import torch
 import torch.distributed as dist
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
 )


@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as distributed_c10d
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
     sharded_op_impl
 )


@@ -4,14 +4,14 @@ from typing import cast
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _communicate_size_to_each_rank,
     _handle_col_wise_sharding_base,
     _handle_row_wise_lookup_distribute,
     _handle_max_norm_col_wise,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )


@@ -7,15 +7,15 @@ import torch.distributed as dist
 from torch._C._distributed_c10d import (
     ReduceOp,
 )
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _communicate_list_to_each_rank,
     _communicate_size_to_each_rank,
     _handle_col_wise_sharding_base,
     _handle_row_wise_lookup_distribute,
     _handle_max_norm_col_wise,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )


@@ -1,5 +1,5 @@
 import torch
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
 )


@@ -2,11 +2,11 @@ from typing import List, cast
 import torch
 import torch.distributed as dist
-from torch.distributed._sharded_tensor.ops._common import (
+from ._common import (
     _handle_col_wise_sharding_base,
 )
-from torch.distributed._sharding_spec import ChunkShardingSpec
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
 )
@@ -15,7 +15,7 @@ from torch.distributed.nn.functional import (
     reduce_scatter,
 )
-from torch.distributed._sharded_tensor import (
+from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     ShardedTensor
 )


@@ -14,13 +14,13 @@ import torch
 import torch.distributed as dist
 from torch.distributed import rpc
 from torch.distributed import distributed_c10d
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     EnumerableShardingSpec,
     ShardMetadata,
     ShardingSpec,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     get_split_size,
     get_chunked_dim_size,
@@ -108,13 +108,13 @@ class ShardedTensor(object):
     ShardedTensor doesn't provide any Tensor like operations but is a wrapper
     providing the Tensor representing the local shard and the global metadata.
-    Using these, users can build their custom distributed sharded computations
+    Using these, users can build their custom distributed._sharded computations
     on top of this primitive. The local shards are all initialized using the
     create_op specified by tensor_init_params.create_op, e.g., torch.ones, or
     torch.empty
     Args:
-        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
             describing how to shard the Tensor.
         size (int...): a sequence of integers defining the shape of the output
             tensor. Can be a variable number of arguments or a collection like a list or tuple.
@@ -138,7 +138,7 @@ class ShardedTensor(object):
     def __new__(cls, *args, **kwargs):
         # Use __new__ for logging purposes.
-        torch._C._log_api_usage_once("torch.distributed.sharded_tensor")
+        torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
         return super(ShardedTensor, cls).__new__(cls)
     def __init__(


@@ -3,7 +3,7 @@ from enum import Enum
 from typing import List
 import torch
-from torch.distributed._sharding_spec import ShardMetadata
+from torch.distributed._shard.sharding_spec import ShardMetadata
 class MEM_FORMAT_ENCODING(Enum):


@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import List, cast
 import torch
-from torch.distributed._sharding_spec import ShardMetadata
+from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed.remote_device import _remote_device
@@ -14,7 +14,7 @@ class Shard(object):
     Args:
         tensor(torch.Tensor): Local tensor for the shard.
-        metadata(:class `torch.distributed._sharded_tensor.ShardMetadata`):
+        metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`):
             The metadata for the shard, including offsets, lengths and device placement.
     """
     __slots__ = ['tensor', 'metadata']


@@ -5,10 +5,10 @@ from typing import Optional, List, Sequence
 import torch
 from torch.distributed import distributed_c10d
 from torch.distributed import rpc
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ShardMetadata,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     validate_non_overlapping_shards_metadata,
 )


@@ -0,0 +1,8 @@
+from .api import (
+    ChunkShardingSpec,
+    DevicePlacementSpec,
+    EnumerableShardingSpec,
+    PlacementSpec,
+    ShardMetadata,
+    ShardingSpec,
+)


@@ -1,540 +1,12 @@
# coding=utf-8
import copy
import functools
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `torch.distributed._shard` package.
import sys
import torch
from torch.distributed._sharding_spec import (
ChunkShardingSpec,
ShardingSpec,
import warnings
from torch.distributed._shard.sharded_tensor import * # noqa: F403
warnings.warn(
"torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead",
DeprecationWarning
)
from torch.distributed._sharding_spec._internals import (
get_chunked_dim_size,
get_split_size,
)
from typing import List
from .api import (
_register_sharded_op,
CreateOp,
Shard,
ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
TensorInitParams,
TensorProperties,
)
from .utils import load_with_process_group
import torch.distributed as dist
from torch.distributed import distributed_c10d
def empty(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with uninitialized data.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
memory_format (:class:`torch.memory_format`, optional): the desired memory format of
returned Tensor. Default: ``torch.contiguous_format``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format, )
tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def ones(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` with the scalar value 1.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(dtype=dtype, layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format, )
tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties)
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def rand(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the
interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format
)
tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def zeros(sharding_spec: ShardingSpec,
*size,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Returns a :class:`ShardedTensor` filled with the scalar value 0.
Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a sequence of integers defining the shape of the output
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format,
)
tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def full(sharding_spec: ShardingSpec,
size,
fill_value=torch.types.Number,
dtype=None,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Creates a :class:`ShardedTensor` filled with fill_value. The tensors dtype
is inferred from fill_value. If dtype is specified, it will override the
inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
output tensor.
fill_value (Scalar) the value to fill the output tensor with.
Keyword args:
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
Default: ``torch.strided``.
requires_grad (bool, optional): If autograd should record operations on the
returned tensor. Default: ``False``.
pin_memory (bool, optional): If set, returned tensor would be allocated in
the pinned memory. Works only for CPU tensors. Default: ``False``.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object on each rank
"""
tensor_properties = TensorProperties(
dtype=dtype, layout=layout, requires_grad=requires_grad,
pin_memory=pin_memory, memory_format=memory_format,
)
tensor_init_params = TensorInitParams(
create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
return ShardedTensor(
sharding_spec,
*size,
tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
def init_from_local_shards(
local_shards: List[Shard],
*global_size,
process_group=None,
init_rrefs=False) -> ShardedTensor:
"""
Creates an :class:`ShardedTensor` from local shards and the global metadata.
Needs to be called on all ranks in an SPMD fashion.
Args:
local_shards (List[:class `torch.distributed._sharded_tensor.Shard`]): A list
of shards that represent the local shards on this rank.
global_size (int...): a list, tuple, or `torch.Size` of integers defining the
shape of the overall sharded tensor.
Keyword args:
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
Default: ``False``.
Returns:
A :class:`ShardedTensor` object handle on this rank
Examples:
Suppose we want construct a sharded tensor on two ranks, global size = (10, 5),
each shard have a (5, 5) local tensor, we can do it like below:
on rank 0:
>>> local_shard_metadata = ShardMetadata(
>>> shard_offsets=[0, 0]
>>> shard_lengths=[5, 5]
>>> placement="rank:0/cuda:0"
>>> )
>>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
>>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
on rank 1:
>>> local_shard_metadata = ShardMetadata(
>>> shard_offsets=[5, 0]
>>> shard_lengths=[5, 5]
>>> placement="rank:1/cuda:1"
>>> )
>>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)]
>>> sharded_tensor = init_from_local_shards(local_shards, [10, 5])
"""
return ShardedTensor._init_from_local_shards(
local_shards,
*global_size,
process_group=process_group,
init_rrefs=init_rrefs
)
def state_dict_hook(module, destination, prefix, local_metadata):
"""
Hook to add ShardedTensor to Module's ``state_dict``. Needs to be
registered to the Module using
:meth:`torch.nn.Module._register_state_dict_hook`.
"""
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
if isinstance(attr, ShardedTensor):
destination[prefix + submodule_name + '.' + attr_name] = attr
def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""
Pre-load state dict hook to add ShardedTensor to the module.
"""
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
key = prefix + submodule_name + '.' + attr_name
if key in state_dict:
if isinstance(state_dict[key], ShardedTensor):
setattr(submodule, attr_name, state_dict[key])
def shard_parameter(
module: torch.nn.Module,
param_name: str,
sharding_spec: ShardingSpec,
src_rank=0,
process_group=None):
"""
Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
module, it shards that parameter according to the provided
``sharding_spec``. ``src_rank`` denotes the source rank which would be
used as the ground truth of the data which would be scattered as shards
across the rest of the ranks.
This method replaces ``module.param_name`` with a
:class:`torch.distributed._sharded_tensor.ShardedTensor`
Args:
module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
param_name (str): Name of the parameter of ``module`` that needs to be sharded.
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
Keyword args:
src_rank (int, optional): The source rank which is used as the ground truth of
the data for the parameter that would be sharded and scattered
across the rest of the ranks.
Default: 0.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
.. warning::
Only :class:`torch.distributed._sharding_spec.ShardingSpec` is
currently supported as the ``sharding_spec``.
"""
# Perform some validation first.
if not isinstance(sharding_spec, ChunkShardingSpec):
raise ValueError('Only ChunkShardingspec is supported.')
if not hasattr(module, param_name):
raise ValueError(f'module: {module} does not have parameter with name: {param_name}')
tensor = getattr(module, param_name)
if not isinstance(tensor, torch.Tensor):
raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')
if not tensor.is_contiguous():
raise ValueError(f'param: {param_name} is not a contiguous Tensor')
pg = process_group if process_group is not None else distributed_c10d._get_default_group()
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
# Validate src_rank and sharding_spec are same across all ranks.
gathered_list = [None] * world_size
dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)
for idx, entry in enumerate(gathered_list):
if src_rank != entry[0]: # type: ignore[index]
raise ValueError(
f'src_rank={src_rank} on rank: {rank} does not ' # type: ignore[index]
f'match with src_rank={entry[0]} on rank: {idx}')
if sharding_spec != entry[1]: # type: ignore[index]
raise ValueError(
f'sharding_spec={sharding_spec} on rank: {rank} does not ' # type: ignore[index]
f'match with sharding_spec={entry[1]} on rank: {idx}')
# Rearrange chunks according to placement.
local_metadata = None
current_offsets = [0] * len(tensor.size())
shards_metadata = []
sharding_dim_size = tensor.size(sharding_spec.dim) # type: ignore[arg-type]
split_size = get_split_size(sharding_dim_size, world_size)
tensor_sizes = list(tensor.size())
for idx, placement in enumerate(sharding_spec.placements):
chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
shard_size = copy.deepcopy(tensor_sizes)
shard_size[sharding_spec.dim] = chunked_dim_size # type: ignore[index]
shard_metadata = ShardMetadata(
shard_offsets=copy.deepcopy(current_offsets),
shard_sizes=shard_size,
placement=placement,
)
shards_metadata.append(shard_metadata)
if rank == placement.rank(): # type: ignore[union-attr]
local_metadata = shard_metadata
current_offsets[sharding_spec.dim] += chunked_dim_size # type: ignore[index]
# Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
dist.broadcast(tensor, src=src_rank, group=pg)
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
local_shard = tensor.narrow(
sharding_spec.dim, # type: ignore[arg-type]
local_metadata.shard_offsets[sharding_spec.dim], # type: ignore[union-attr, arg-type, index]
local_metadata.shard_sizes[sharding_spec.dim], # type: ignore[union-attr, index]
).clone().detach().contiguous()
# Sync requires_grad to local_shard.
local_shard.requires_grad = tensor.requires_grad
# Create ShardedTensor based on local shards.
local_shards = [
Shard(
tensor=local_shard,
metadata=local_metadata, # type: ignore[arg-type]
)
]
st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)
# Manually set sharding_spec
st._sharding_spec = sharding_spec
# Replace param with ShardedTensor.
# Need to delete the attribute first since param_name might be
# torch.nn.Parameter and can't be replaced with ShardedTensor which is
# not torch.nn.Parameter.
delattr(module, param_name)
# Now we can set the attribute appropriately.
setattr(module, param_name, st)
def sharded_op_impl(func):
"""
Provides a way for users to write their own custom sharded operator. This
can be used to override existing ShardedTensor operators or write a new
one not supported by ShardedTensor. If the operator in question is covered
by ``__torch_function__`` dispatch and has a ShardedTensor as any of its
parameters, the function provided will be invoked for that operator.
Example::
>>> @custom_sharded_op(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ....
>>>
>>> input = torch.rand(10, 32)
>>> weight = _sharded_tensor.rand(32, 16)
>>> bias = torch.rand(16)
>>> # This will call 'my_custom_sharded_linear'
>>> torch.nn.functional.linear(input, weight, bias)
The types, args and kwargs parameters are the same parameters that are
passed to ``__torch_function__`` dispatch API
(https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
There is an additional ``process_group`` parameter which is the
process_group used for the ShardedTensor and can be used by
implementations for communications within a sharded implementation.
Args:
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
def decorator_sharded_func(wrapped_func):
_register_sharded_op(func, wrapped_func)
@functools.wraps(wrapped_func)
def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)
return wrapper
return decorator_sharded_func
# Import all builtin sharded ops
from .ops import * # noqa: F403
sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor


@@ -1,8 +1,12 @@
-from .api import (
-    ChunkShardingSpec,
-    DevicePlacementSpec,
-    EnumerableShardingSpec,
-    PlacementSpec,
-    ShardMetadata,
-    ShardingSpec,
+# Keep old package for BC purposes, this file should be removed once
+# everything moves to the `torch.distributed._shard` package.
+import sys
+import torch
+import warnings
+from torch.distributed._shard.sharding_spec import *  # noqa: F403
+warnings.warn(
+    "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead",
+    DeprecationWarning
 )
+sys.modules['torch.distributed._sharding_spec'] = torch.distributed._shard.sharding_spec


@@ -1,7 +1,7 @@
-from torch.distributed._sharding_spec import (
+from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
 )
-from torch.distributed._sharding_spec._internals import (
+from torch.distributed._shard.sharding_spec._internals import (
     get_chunked_dim_size,
     get_split_size,
 )